Compare commits

1 commit

Author: Antonio Scandurra
SHA1: 6a75936419
Message: WIP
Date: 2024-07-20 12:45:18 +02:00

535 changed files with 12904 additions and 22097 deletions

View File

@@ -12,7 +12,3 @@ rustflags = ["-C", "link-arg=-fuse-ld=mold"]
[target.aarch64-unknown-linux-gnu]
linker = "clang"
rustflags = ["-C", "link-arg=-fuse-ld=mold"]
# This cfg will reduce the size of `windows::core::Error` from 16 bytes to 4 bytes
[target.'cfg(target_os = "windows")']
rustflags = ["--cfg", "windows_slim_errors"]

View File

@@ -10,7 +10,7 @@ runs:
cargo install cargo-nextest
- name: Install Node
uses: actions/setup-node@1e60f620b9541d16bece96c5465dc8ee9832be0b # v4
uses: actions/setup-node@v4
with:
node-version: "18"

View File

@@ -19,7 +19,7 @@ jobs:
- test
steps:
- name: Checkout code
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v2
with:
ref: ${{ github.event.inputs.branch }}
ssh-key: ${{ secrets.ZED_BOT_DEPLOY_KEY }}

View File

@@ -30,7 +30,7 @@ jobs:
- test
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
fetch-depth: 0
@@ -40,15 +40,10 @@ jobs:
- name: Check spelling
run: |
if ! cargo install --list | grep "typos-cli v$TYPOS_CLI_VERSION" > /dev/null; then
echo "Installing typos-cli@$TYPOS_CLI_VERSION..."
cargo install "typos-cli@$TYPOS_CLI_VERSION"
else
echo "typos-cli@$TYPOS_CLI_VERSION is already installed."
if ! which typos > /dev/null; then
cargo install typos-cli
fi
typos
env:
TYPOS_CLI_VERSION: "1.23.3"
- name: Run style checks
uses: ./.github/actions/check_style
@@ -90,7 +85,7 @@ jobs:
- test
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
@@ -117,7 +112,7 @@ jobs:
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
@@ -137,18 +132,17 @@ jobs:
runs-on: hosted-windows-1
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
uses: swatinem/rust-cache@v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
- name: cargo clippy
# Windows can't run shell scripts, so we need to use `cargo xtask`.
run: cargo xtask clippy
run: ./script/clippy
- name: Build Zed
run: cargo build -p zed
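
The hunk above swaps `./script/clippy` for `cargo xtask clippy` because the Windows runner cannot execute shell scripts. The xtask pattern is just an ordinary binary crate dispatching on its first argument (the workspace Cargo.toml below lists `tooling/xtask` as a member); the following is a hedged sketch, and the exact clippy flags Zed's script passes are an assumption:

// Minimal cargo-xtask dispatcher sketch. Conventionally invoked as
// `cargo xtask <task>` through a `[alias]` entry in .cargo/config
// (an assumption here, not confirmed by this diff).
use std::process::{exit, Command};

fn main() {
    let task = std::env::args().nth(1).unwrap_or_default();
    let status = match task.as_str() {
        // Assumed flags; the real script/clippy may pass different ones.
        "clippy" => Command::new("cargo")
            .args(["clippy", "--workspace", "--all-targets", "--", "--deny", "warnings"])
            .status()
            .expect("failed to spawn cargo"),
        other => {
            eprintln!("unknown task: {other:?} (expected e.g. `clippy`)");
            exit(1);
        }
    };
    exit(status.code().unwrap_or(1));
}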
@@ -171,12 +165,12 @@ jobs:
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
steps:
- name: Install Node
uses: actions/setup-node@1e60f620b9541d16bece96c5465dc8ee9832be0b # v4
uses: actions/setup-node@v4
with:
node-version: "18"
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
# We need to fetch more than one commit so that `script/draft-release-notes`
# is able to diff between the current and previous tag.
@@ -231,34 +225,32 @@ jobs:
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
- name: Upload app bundle (universal) to workflow run if main branch or specific label
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
uses: actions/upload-artifact@v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}.dmg
path: target/release/Zed.dmg
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
uses: actions/upload-artifact@v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
uses: actions/upload-artifact@v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
path: target/x86_64-apple-darwin/release/Zed-x86_64.dmg
- uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
- uses: softprops/action-gh-release@v1
name: Upload app bundle to release
if: ${{ env.RELEASE_CHANNEL == 'preview' || env.RELEASE_CHANNEL == 'stable' }}
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: |
target/zed-remote-server-macos-x86_64.gz
target/zed-remote-server-macos-aarch64.gz
target/aarch64-apple-darwin/release/Zed-aarch64.dmg
target/x86_64-apple-darwin/release/Zed-x86_64.dmg
target/release/Zed.dmg
@@ -281,7 +273,7 @@ jobs:
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
@@ -319,20 +311,18 @@ jobs:
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
uses: actions/upload-artifact@v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
path: target/release/zed-*.tar.gz
- name: Upload app bundle to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
uses: softprops/action-gh-release@v1
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: |
target/zed-remote-server-linux-x86_64.gz
target/release/zed-linux-x86_64.tar.gz
files: target/release/zed-linux-x86_64.tar.gz
body: ""
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -348,11 +338,11 @@ jobs:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
- name: "Setup jq"
uses: dcarbone/install-jq-action@8867ddb4788346d7c22b72ea2e2ffe4d514c7bcb # v2
uses: dcarbone/install-jq-action@v2
- name: Set up Clang
run: |
@@ -360,7 +350,7 @@ jobs:
sudo apt-get install -y llvm-10 clang-10 build-essential cmake pkg-config libasound2-dev libfontconfig-dev libwayland-dev libxkbcommon-x11-dev libssl-dev libsqlite3-dev libzstd-dev libvulkan1 libgit2-dev
echo "/usr/lib/llvm-10/bin" >> $GITHUB_PATH
- uses: rui314/setup-mold@2e332a0b602c2fc65d2d3995941b1b29a5f554a0 # v1
- uses: rui314/setup-mold@v1
with:
mold-version: 2.32.0
@@ -403,21 +393,19 @@ jobs:
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
uses: actions/upload-artifact@v4
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
path: target/release/zed-*.tar.gz
- name: Upload app bundle to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
uses: softprops/action-gh-release@v1
if: ${{ env.RELEASE_CHANNEL == 'preview' || env.RELEASE_CHANNEL == 'stable' }}
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: |
target/zed-remote-server-linux-aarch64.gz
target/release/zed-linux-aarch64.tar.gz
files: target/release/zed-linux-aarch64.tar.gz
body: ""
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -14,14 +14,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
- uses: actions/checkout@v4
- uses: pnpm/action-setup@a3252b78c470c02df07e9d59298aecedc3ccdd6d # v3
- uses: pnpm/action-setup@v3
with:
version: 9
- name: Setup Node
uses: actions/setup-node@1e60f620b9541d16bece96c5465dc8ee9832be0b # v4
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "pnpm"

View File

@@ -12,12 +12,12 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
- name: Setup mdBook
uses: peaceiris/actions-mdbook@ee69d230fe19748b7abf22df32acaa93833fad08 # v2
uses: peaceiris/actions-mdbook@v2
with:
mdbook-version: "0.4.37"
@@ -28,28 +28,28 @@ jobs:
mdbook build ./docs --dest-dir=../target/deploy/docs/
- name: Deploy Docs
uses: cloudflare/wrangler-action@f84a562284fc78278ff9052435d9526f9c718361 # v3
uses: cloudflare/wrangler-action@v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: pages deploy target/deploy --project-name=docs
- name: Deploy Install
uses: cloudflare/wrangler-action@f84a562284fc78278ff9052435d9526f9c718361 # v3
uses: cloudflare/wrangler-action@v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: r2 object put -f script/install.sh zed-open-source-website-assets/install.sh
- name: Deploy Docs Workers
uses: cloudflare/wrangler-action@f84a562284fc78278ff9052435d9526f9c718361 # v3
uses: cloudflare/wrangler-action@v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
command: deploy .cloudflare/docs-proxy/src/worker.js
- name: Deploy Install Workers
uses: cloudflare/wrangler-action@f84a562284fc78278ff9052435d9526f9c718361 # v3
uses: cloudflare/wrangler-action@v3
with:
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}

View File

@@ -18,7 +18,7 @@ jobs:
- test
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
fetch-depth: 0
@@ -37,7 +37,7 @@ jobs:
needs: style
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
fetch-depth: 0
@@ -71,7 +71,7 @@ jobs:
run: doctl registry login
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false

View File

@@ -1,24 +0,0 @@
name: Docs
on:
pull_request:
paths:
- "docs/**"
push:
branches:
- main
jobs:
check_formatting:
name: "Check formatting"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
- uses: pnpm/action-setup@a3252b78c470c02df07e9d59298aecedc3ccdd6d # v3
with:
version: 9
- run: pnpm dlx prettier . --check
working-directory: ./docs

View File

@@ -16,12 +16,12 @@ jobs:
- ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
uses: swatinem/rust-cache@v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}

View File

@@ -23,12 +23,12 @@ jobs:
- randomized-tests
steps:
- name: Install Node
uses: actions/setup-node@1e60f620b9541d16bece96c5465dc8ee9832be0b # v4
uses: actions/setup-node@v4
with:
node-version: "18"
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false

View File

@@ -16,7 +16,7 @@ jobs:
fi
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@67b1b814955634208b103cff064be3cb1c7a19be # v1.3.0
uses: 2428392/gh-truncate-string-action@v1.3.0
id: get-content
with:
stringToTruncate: |
@@ -26,7 +26,7 @@ jobs:
maxLength: 2000
truncationSymbol: "..."
- name: Discord Webhook Action
uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0
uses: tsickert/discord-webhook@v5.3.0
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}

View File

@@ -23,7 +23,7 @@ jobs:
- test
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
fetch-depth: 0
@@ -33,7 +33,6 @@ jobs:
- name: Run clippy
run: ./script/clippy
tests:
timeout-minutes: 60
name: Run tests
@@ -44,7 +43,7 @@ jobs:
needs: style
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
@@ -69,12 +68,12 @@ jobs:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
steps:
- name: Install Node
uses: actions/setup-node@1e60f620b9541d16bece96c5465dc8ee9832be0b # v4
uses: actions/setup-node@v4
with:
node-version: "18"
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
@@ -108,7 +107,7 @@ jobs:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
@@ -133,6 +132,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for ARM
if: github.repository_owner == 'zed-industries'
runs-on:
- self-hosted
- hosted-linux-arm-1
needs: tests
env:
@@ -141,12 +141,12 @@ jobs:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
uses: actions/checkout@v4
with:
clean: false
- name: "Setup jq"
uses: dcarbone/install-jq-action@8867ddb4788346d7c22b72ea2e2ffe4d514c7bcb # v2
uses: dcarbone/install-jq-action@v2
- name: Set up Clang
run: |
@@ -154,7 +154,7 @@ jobs:
sudo apt-get install -y llvm-10 clang-10 build-essential cmake pkg-config libasound2-dev libfontconfig-dev libwayland-dev libxkbcommon-x11-dev libssl-dev libsqlite3-dev libzstd-dev libvulkan1 libgit2-dev
echo "/usr/lib/llvm-10/bin" >> $GITHUB_PATH
- uses: rui314/setup-mold@2e332a0b602c2fc65d2d3995941b1b29a5f554a0 # v1
- uses: rui314/setup-mold@v1
with:
mold-version: 2.32.0

View File

@@ -8,8 +8,8 @@ jobs:
runs-on: ubuntu-latest
if: github.repository_owner == 'zed-industries'
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
architecture: "x64"

View File

@@ -8,8 +8,8 @@ jobs:
runs-on: ubuntu-latest
if: github.repository_owner == 'zed-industries'
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
architecture: "x64"

View File

@@ -26,10 +26,6 @@
"tab_size": 2,
"formatter": "prettier"
},
"CSS": {
"tab_size": 2,
"formatter": "prettier"
},
"Rust": {
"tasks": {
"variables": {

Cargo.lock (generated, 1957 changed lines)

File diff suppressed because it is too large.

View File

@@ -1,11 +1,11 @@
[workspace]
resolver = "2"
members = [
"crates/activity_indicator",
"crates/anthropic",
"crates/assets",
"crates/assistant",
"crates/assistant_slash_command",
"crates/assistant_tooling",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@@ -19,7 +19,6 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
"crates/completion",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
@@ -44,14 +43,13 @@ members = [
"crates/gpui_macros",
"crates/headless",
"crates/html_to_markdown",
"crates/http_client",
"crates/http",
"crates/image_viewer",
"crates/indexed_docs",
"crates/inline_completion_button",
"crates/install_cli",
"crates/journal",
"crates/language",
"crates/language_model",
"crates/language_selector",
"crates/language_tools",
"crates/languages",
@@ -81,8 +79,6 @@ members = [
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/release_channel",
"crates/remote",
"crates/remote_server",
"crates/repl",
"crates/rich_text",
"crates/rope",
@@ -90,9 +86,7 @@ members = [
"crates/search",
"crates/semantic_index",
"crates/semantic_version",
"crates/session",
"crates/settings",
"crates/settings_ui",
"crates/snippet",
"crates/snippet_provider",
"crates/sqlez",
@@ -125,10 +119,6 @@ members = [
"crates/zed",
"crates/zed_actions",
#
# Extensions
#
"extensions/astro",
"extensions/clojure",
"extensions/csharp",
@@ -147,7 +137,6 @@ members = [
"extensions/php",
"extensions/prisma",
"extensions/purescript",
"extensions/ruff",
"extensions/ruby",
"extensions/snippets",
"extensions/svelte",
@@ -158,25 +147,19 @@ members = [
"extensions/vue",
"extensions/zig",
#
# Tooling
#
"tooling/xtask",
]
default-members = ["crates/zed"]
resolver = "2"
[workspace.dependencies]
#
# Workspace member crates
#
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_tooling = { path = "crates/assistant_tooling" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
breadcrumbs = { path = "crates/breadcrumbs" }
@@ -190,7 +173,6 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
completion = { path = "crates/completion" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
@@ -213,14 +195,13 @@ gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
headless = { path = "crates/headless" }
html_to_markdown = { path = "crates/html_to_markdown" }
http_client = { path = "crates/http_client" }
http = { path = "crates/http" }
image_viewer = { path = "crates/image_viewer" }
indexed_docs = { path = "crates/indexed_docs" }
inline_completion_button = { path = "crates/inline_completion_button" }
install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_model = { path = "crates/language_model" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" }
@@ -249,10 +230,7 @@ project_symbols = { path = "crates/project_symbols" }
proto = { path = "crates/proto" }
quick_action_bar = { path = "crates/quick_action_bar" }
recent_projects = { path = "crates/recent_projects" }
refineable = { path = "crates/refineable" }
release_channel = { path = "crates/release_channel" }
remote = { path = "crates/remote" }
remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
rich_text = { path = "crates/rich_text" }
rope = { path = "crates/rope" }
@@ -260,9 +238,7 @@ rpc = { path = "crates/rpc" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" }
session = { path = "crates/session" }
settings = { path = "crates/settings" }
settings_ui = { path = "crates/settings_ui" }
snippet = { path = "crates/snippet" }
snippet_provider = { path = "crates/snippet_provider" }
sqlez = { path = "crates/sqlez" }
@@ -295,44 +271,37 @@ worktree = { path = "crates/worktree" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
#
# External crates
#
aho-corasick = "1.1"
alacritty_terminal = "0.23"
any_vec = "0.14"
anyhow = "1.0.86"
any_vec = "0.13"
anyhow = "1.0.57"
ashpd = "0.9.1"
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
async-dispatcher = "0.1"
async-dispatcher = { version = "0.1" }
async-fs = "1.6"
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
async-recursion = "1.0.0"
async-tar = "0.4.2"
async-trait = "0.1"
async-tungstenite = "0.23"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
base64 = "0.22"
base64 = "0.13"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
blade-macros = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
blade-util = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
cargo_metadata = "0.18"
blade-graphics = { git = "https://github.com/zed-industries/blade", rev = "a477c2008db27db0b9f745715e119b3ee7ab7818" }
blade-macros = { git = "https://github.com/zed-industries/blade", rev = "a477c2008db27db0b9f745715e119b3ee7ab7818" }
blade-util = { git = "https://github.com/zed-industries/blade", rev = "a477c2008db27db0b9f745715e119b3ee7ab7818" }
cap-std = "3.0"
cargo_toml = "0.20"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.4", features = ["derive"] }
clickhouse = "0.11.6"
clickhouse = { version = "0.11.6" }
cocoa = "0.25"
core-foundation = "0.9.3"
core-foundation = { version = "0.9.3" }
core-foundation-sys = "0.8.6"
ctor = "0.2.6"
dashmap = "6.0"
dashmap = "5.5.3"
derive_more = "0.99.17"
dirs = "4.0"
emojis = "0.6.1"
env_logger = "0.11"
env_logger = "0.9"
exec = "0.3.1"
fork = "0.1.23"
futures = "0.3"
@@ -346,17 +315,16 @@ html5ever = "0.27.0"
ignore = "0.4.22"
image = "0.25.1"
indexmap = { version = "1.6.2", features = ["serde"] }
indoc = "2"
indoc = "1"
# We explicitly disable http2 support in isahc.
isahc = { version = "1.7.2", default-features = false, features = [
"text-decoding",
] }
itertools = "0.11.0"
jsonwebtoken = "9.3"
lazy_static = "1.4.0"
libc = "0.2"
linkify = "0.10.0"
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
markup5ever_rcdom = "0.3.0"
nanoid = "0.4"
nix = "0.28"
@@ -374,16 +342,15 @@ prost-build = "0.9"
prost-types = "0.9"
pulldown-cmark = { version = "0.10.0", default-features = false }
rand = "0.8.5"
refineable = { path = "./crates/refineable" }
regex = "1.5"
repair_json = "0.1.0"
rsa = "0.9.6"
runtimelib = { version = "0.12", default-features = false, features = [
"async-dispatcher-runtime",
] }
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
rustc-demangle = "0.1.23"
rust-embed = { version = "8.4", features = ["include-exclude"] }
schemars = {version = "0.8", features = ["impl_json_schema"]}
schemars = "0.8"
semver = "1.0"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
@@ -403,7 +370,6 @@ smallvec = { version = "1.6", features = ["union"] }
smol = "1.2"
strum = { version = "0.25.0", features = ["derive"] }
subtle = "2.5.0"
sys-locale = "0.3.1"
sysinfo = "0.30.7"
tempfile = "3.9.0"
thiserror = "1.0.29"
@@ -415,32 +381,32 @@ time = { version = "0.3", features = [
"serde-well-known",
"formatting",
] }
tiny_http = "0.8"
toml = "0.8"
tokio = { version = "1", features = ["full"] }
tower-http = "0.4.4"
tree-sitter = { version = "0.22", features = ["wasm"] }
tree-sitter-bash = "0.21"
tree-sitter-c = "0.21"
tree-sitter-cpp = "0.22"
tree-sitter-css = "0.21"
tree-sitter-elixir = "0.2"
tree-sitter = { version = "0.20", features = ["wasm"] }
tree-sitter-bash = "0.20.5"
tree-sitter-c = "0.20.1"
tree-sitter-cpp = "0.20.5"
tree-sitter-css = "0.20"
tree-sitter-elixir = "0.1.1"
tree-sitter-embedded-template = "0.20.0"
tree-sitter-go = "0.21"
tree-sitter-go-mod = { git = "https://github.com/SomeoneToIgnore/tree-sitter-go-mod", rev = "8c1f54f12bb4c846336b634bc817645d6f35d641", package = "tree-sitter-gomod" }
tree-sitter-gowork = { git = "https://github.com/d1y/tree-sitter-go-work", rev = "dcbabff454703c3a4bc98a23cf8778d4be46fd22" }
tree-sitter-heex = { git = "https://github.com/phoenixframework/tree-sitter-heex", rev = "6dd0303acf7138dd2b9b432a229e16539581c701" }
tree-sitter-html = "0.20"
tree-sitter-jsdoc = "0.21"
tree-sitter-json = "0.21"
tree-sitter-md = { git = "https://github.com/zed-industries/tree-sitter-markdown", rev = "e3855e37f8f2c71aa7513c18a9c95fb7461b1b10" }
protols-tree-sitter-proto = "0.2"
tree-sitter-python = "0.21"
tree-sitter-regex = "0.21"
tree-sitter-ruby = "0.21"
tree-sitter-rust = "0.21"
tree-sitter-typescript = "0.21"
tree-sitter-yaml = "0.6"
tree-sitter-go = { git = "https://github.com/tree-sitter/tree-sitter-go", rev = "b82ab803d887002a0af11f6ce63d72884580bf33" }
tree-sitter-gomod = "1.0.1"
tree-sitter-gowork = { git = "https://github.com/d1y/tree-sitter-go-work" }
rustc-demangle = "0.1.23"
tree-sitter-heex = { git = "https://github.com/phoenixframework/tree-sitter-heex", rev = "2e1348c3cf2c9323e87c2744796cf3f3868aa82a" }
tree-sitter-html = "0.19.0"
tree-sitter-jsdoc = { git = "https://github.com/tree-sitter/tree-sitter-jsdoc", rev = "6a6cf9e7341af32d8e2b2e24a37fbfebefc3dc55" }
tree-sitter-json = "0.20.2"
tree-sitter-markdown = { git = "https://github.com/MDeiml/tree-sitter-markdown", rev = "330ecab87a3e3a7211ac69bbadc19eabecdb1cca" }
tree-sitter-proto = { git = "https://github.com/rewinfrey/tree-sitter-proto", rev = "36d54f288aee112f13a67b550ad32634d0c2cb52" }
tree-sitter-python = "0.20.2"
tree-sitter-regex = "0.20.0"
tree-sitter-ruby = "0.20.0"
tree-sitter-rust = "0.20.3"
tree-sitter-typescript = "0.20.5"
tree-sitter-yaml = "0.0.1"
unindent = "0.1.7"
unicase = "2.6"
unicode-segmentation = "1.10"
@@ -448,19 +414,20 @@ url = "2.2"
uuid = { version = "1.1.2", features = ["v4", "v5", "serde"] }
wasmparser = "0.201"
wasm-encoder = "0.201"
wasmtime = { version = "21.0.1", default-features = false, features = [
wasmtime = { version = "19.0.0", default-features = false, features = [
"async",
"demangle",
"runtime",
"cranelift",
"component-model",
] }
wasmtime-wasi = "21.0.1"
wasmtime-wasi = "19.0.0"
which = "6.0.0"
wit-component = "0.201"
sys-locale = "0.3.1"
[workspace.dependencies.windows]
version = "0.58"
version = "0.57"
features = [
"implement",
"Foundation_Numerics",
@@ -480,6 +447,7 @@ features = [
"Win32_Security",
"Win32_Security_Credentials",
"Win32_Storage_FileSystem",
"Win32_System_LibraryLoader",
"Win32_System_Com",
"Win32_System_Com_StructuredStorage",
"Win32_System_DataExchange",
@@ -499,8 +467,9 @@ features = [
]
[patch.crates-io]
# Patch Tree-sitter for updated wasmtime.
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "7f4a57817d58a2f134fe863674acad6bbf007228" }
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "7b4894ba2ae81b988846676f54c0988d4027ef4f" }
# Workaround for a broken nightly build of gpui: See #7644 and revisit once 0.5.3 is released.
pathfinder_simd = { git = "https://github.com/servo/pathfinder.git", rev = "4968e819c0d9b015437ffc694511e175801a17c7" }
[profile.dev]
split-debuginfo = "unpacked"
@@ -553,6 +522,13 @@ single_range_in_vec_init = "allow"
style = { level = "allow", priority = -1 }
# Individual rules that have violations in the codebase:
almost_complete_range = "allow"
arc_with_non_send_sync = "allow"
borrowed_box = "allow"
let_underscore_future = "allow"
map_entry = "allow"
non_canonical_partial_ord_impl = "allow"
reversed_empty_ranges = "allow"
type_complexity = "allow"
[workspace.metadata.cargo-machete]

View File

@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:1.2
FROM rust:1.80-bookworm as builder
FROM rust:1.79-bookworm as builder
WORKDIR app
COPY . .

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye"><path d="M2.062 12.348a1 1 0 0 1 0-.696 10.75 10.75 0 0 1 19.876 0 1 1 0 0 1 0 .696 10.75 10.75 0 0 1-19.876 0"/><circle cx="12" cy="12" r="3"/></svg>

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-file-code"><path d="M10 12.5 8 15l2 2.5"/><path d="m14 12.5 2 2.5-2 2.5"/><path d="M14 2v4a2 2 0 0 0 2 2h4"/><path d="M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7z"/></svg>

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-file-text"><path d="M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7Z"/><path d="M14 2v4a2 2 0 0 0 2 2h4"/><path d="M10 9H8"/><path d="M16 13H8"/><path d="M16 17H8"/></svg>

View File

@@ -1,6 +0,0 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3 4H8" stroke="black" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6 10L11 10" stroke="black" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"/>
<circle cx="4" cy="10" r="1.875" stroke="black" stroke-width="1.75"/>
<circle cx="10" cy="4" r="1.875" stroke="black" stroke-width="1.75"/>
</svg>

View File

@@ -1,3 +0,0 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 6C5.69062 6.30938 4.56159 6.55977 3.51192 6.73263C3.27345 6.7719 3.27345 7.2281 3.51192 7.26737C4.56159 7.44023 5.69062 7.69062 6 8C6.30938 8.30938 6.55977 9.43841 6.73263 10.4881C6.7719 10.7266 7.2281 10.7266 7.26737 10.4881C7.44023 9.43841 7.69062 8.30938 8 8C8.30938 7.69062 9.43841 7.44023 10.4881 7.26737C10.7266 7.2281 10.7266 6.7719 10.4881 6.73263C9.43841 6.55977 8.30938 6.30938 8 6C7.69062 5.69062 7.44023 4.56159 7.26737 3.51192C7.2281 3.27345 6.7719 3.27345 6.73263 3.51192C6.55977 4.56159 6.30938 5.69062 6 6Z" stroke="black" stroke-width="1.25" stroke-linejoin="round"/>
</svg>

View File

@@ -1,3 +1,3 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 9.8V4.2C4 4.08954 4.08954 4 4.2 4H9.8C9.91046 4 10 4.08954 10 4.2V9.8C10 9.91046 9.91046 10 9.8 10H4.2C4.08954 10 4 9.91046 4 9.8Z" stroke="#C56757" stroke-width="1.25" stroke-linejoin="round"/>
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M9.88889 1H2.11111C1.49746 1 1 1.49746 1 2.11111V9.88889C1 10.5025 1.49746 11 2.11111 11H9.88889C10.5025 11 11 10.5025 11 9.88889V2.11111C11 1.49746 10.5025 1 9.88889 1Z" stroke="#C56757" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

View File

@@ -40,6 +40,7 @@
"backspace": "editor::Backspace",
"shift-backspace": "editor::Backspace",
"delete": "editor::Delete",
"ctrl-d": "editor::Delete",
"tab": "editor::Tab",
"shift-tab": "editor::TabPrev",
"ctrl-k": "editor::CutToEndOfLine",
@@ -105,8 +106,8 @@
"bindings": {
"enter": "editor::Newline",
"shift-enter": "editor::Newline",
"ctrl-enter": "editor::NewlineAbove",
"ctrl-shift-enter": "editor::NewlineBelow",
"ctrl-enter": "editor::NewlineAbove",
"alt-z": "editor::ToggleSoftWrap",
"ctrl-f": "buffer_search::Deploy",
"ctrl-h": ["buffer_search::Deploy", { "replace_enabled": true }],
@@ -268,7 +269,6 @@
"alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }],
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }],
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }],
@@ -477,12 +477,6 @@
"ctrl-enter": "assistant::InlineAssist"
}
},
{
"context": "Editor && jupyter && !ContextEditor",
"bindings": {
"ctrl-shift-enter": "repl::Run"
}
},
{
"context": "ContextEditor > Editor",
"bindings": {
@@ -601,7 +595,6 @@
"ctrl-alt-space": "terminal::ShowCharacterPalette",
"shift-ctrl-c": "terminal::Copy",
"ctrl-insert": "terminal::Copy",
"ctrl-a": "editor::SelectAll",
"shift-ctrl-v": "terminal::Paste",
"shift-insert": "terminal::Paste",
"ctrl-enter": "assistant::InlineAssist",

View File

@@ -135,7 +135,6 @@
"bindings": {
"enter": "editor::Newline",
"shift-enter": "editor::Newline",
"cmd-enter": "editor::NewlineBelow",
"cmd-shift-enter": "editor::NewlineAbove",
"alt-z": "editor::ToggleSoftWrap",
"cmd-f": "buffer_search::Deploy",
@@ -147,6 +146,12 @@
"cmd-alt-e": "editor::SelectEnclosingSymbol"
}
},
{
"context": "Editor && mode == full && !jupyter",
"bindings": {
"cmd-enter": "editor::NewlineBelow"
}
},
{
"context": "Editor && mode == full && inline_completion",
"bindings": {
@@ -175,12 +180,6 @@
"cmd-c": "markdown::Copy"
}
},
{
"context": "Editor && jupyter && !ContextEditor",
"bindings": {
"ctrl-shift-enter": "repl::Run"
}
},
{
"context": "AssistantPanel",
"bindings": {
@@ -293,6 +292,7 @@
"alt-cmd-c": "search::ToggleCaseSensitive",
"alt-cmd-w": "search::ToggleWholeWord",
"alt-cmd-f": "project_search::ToggleFilters",
"alt-cmd-g": "search::ToggleRegex",
"alt-cmd-x": "search::ToggleRegex"
}
},
@@ -570,6 +570,12 @@
"space": "project_panel::Open"
}
},
{
"context": "Editor && jupyter && !ContextEditor",
"bindings": {
"cmd-enter": "repl::Run"
}
},
{
"context": "CollabPanel && not_editing",
"bindings": {
@@ -622,7 +628,6 @@
"ctrl-cmd-space": "terminal::ShowCharacterPalette",
"cmd-c": "terminal::Copy",
"cmd-v": "terminal::Paste",
"cmd-a": "editor::SelectAll",
"cmd-k": "terminal::Clear",
"ctrl-enter": "assistant::InlineAssist",
// Some nice conveniences

View File

@@ -12,8 +12,8 @@
{
"context": "Editor",
"bindings": {
"ctrl-shift-up": "editor::MoveLineUp",
"ctrl-shift-down": "editor::MoveLineDown",
"ctrl-shift-up": "editor::AddSelectionAbove",
"ctrl-shift-down": "editor::AddSelectionBelow",
"ctrl-shift-m": "editor::SelectLargerSyntaxNode",
"ctrl-shift-l": "editor::SplitSelectionIntoLines",
"ctrl-shift-a": "editor::SelectLargerSyntaxNode",

View File

@@ -14,8 +14,6 @@
"bindings": {
"ctrl-shift-up": "editor::AddSelectionAbove",
"ctrl-shift-down": "editor::AddSelectionBelow",
"cmd-ctrl-up": "editor::MoveLineUp",
"cmd-ctrl-down": "editor::MoveLineDown",
"cmd-shift-space": "editor::SelectAll",
"ctrl-shift-m": "editor::SelectLargerSyntaxNode",
"cmd-shift-l": "editor::SplitSelectionIntoLines",

View File

@@ -253,7 +253,7 @@
"[ d": "editor::GoToPrevDiagnostic",
"] c": "editor::GoToHunk",
"[ c": "editor::GoToPrevHunk",
"g c": ["vim::PushOperator", "ToggleComments"]
"g c c": "vim::ToggleComments"
}
},
{
@@ -434,12 +434,6 @@
"<": "vim::CurrentLine"
}
},
{
"context": "vim_operator == gc",
"bindings": {
"c": "vim::CurrentLine"
}
},
{
"context": "BufferSearchBar && !in_replace",
"bindings": {

View File

@@ -26,9 +26,6 @@
},
// The name of a font to use for rendering text in the editor
"buffer_font_family": "Zed Plex Mono",
// Set the buffer text's font fallbacks, this will be merged with
// the platform's default fallbacks.
"buffer_font_fallbacks": [],
// The OpenType features to enable for text in the editor.
"buffer_font_features": {
// Disable ligatures:
@@ -50,11 +47,8 @@
// },
"buffer_line_height": "comfortable",
// The name of a font to use for rendering text in the UI
// You can set this to ".SystemUIFont" to use the system font
// (On macOS) You can set this to ".SystemUIFont" to use the system font
"ui_font_family": "Zed Plex Sans",
// Set the UI's font fallbacks, this will be merged with the platform's
// default font fallbacks.
"ui_font_fallbacks": [],
// The OpenType features to enable for text in the UI
"ui_font_features": {
// Disable ligatures:
@@ -88,7 +82,7 @@
// Whether to confirm before quitting Zed.
"confirm_quit": false,
// Whether to restore last closed project when fresh Zed instance is opened.
"restore_on_startup": "last_session",
"restore_on_startup": "last_workspace",
// Size of the drop target in the editor.
"drop_target_size": 0.2,
// Whether the window should be closed when using 'close active item' on a window with no tabs.
@@ -318,7 +312,7 @@
"auto_reveal_entries": true,
// Whether to fold directories automatically and show compact folders
// (e.g. "a/b/c" ) when a directory has only one subdirectory inside.
"auto_fold_dirs": true,
"auto_fold_dirs": false,
/// Scrollbar-related settings
"scrollbar": {
/// When to show the scrollbar in the project panel.
@@ -381,7 +375,7 @@
},
"assistant": {
// Version of this setting.
"version": "2",
"version": "1",
// Whether the assistant is enabled.
"enabled": true,
// Whether to show the assistant panel button in the status bar.
@@ -392,12 +386,18 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
// The default model to use when creating new contexts.
"default_model": {
// The provider to use.
"provider": "openai",
// The model to use.
"model": "gpt-4o"
// AI provider.
"provider": {
"name": "openai",
// The default model to use when creating new contexts. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo"
// 2. "gpt-4"
// 3. "gpt-4-turbo-preview"
// 4. "gpt-4o"
// 5. "gpt-4o-mini"
"default_model": "gpt-4o"
}
},
// Whether the screen sharing icon is shown in the os status bar.
@@ -681,10 +681,6 @@
// Set the terminal's font family. If this option is not included,
// the terminal will default to matching the buffer's font family.
// "font_family": "Zed Plex Mono",
// Set the terminal's font fallbacks. If this option is not included,
// the terminal will default to matching the buffer's font fallbacks.
// This will be merged with the platform's default font fallbacks
// "font_fallbacks": ["FiraCode Nerd Fonts"],
// Sets the maximum number of lines in the terminal's scrollback buffer.
// Default: 10_000, maximum: 100_000 (all bigger values set will be treated as 100_000), 0 disables the scrolling.
// Existing terminals will not pick up this change until they are recreated.
@@ -709,7 +705,7 @@
//
"file_types": {
"JSON": ["flake.lock"],
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "tsconfig.json"]
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json"]
},
// The extensions that Zed should automatically install on startup.
//
@@ -747,9 +743,6 @@
"Elixir": {
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
},
"Erlang": {
"language_servers": ["erlang-ls", "!elp", "..."]
},
"Go": {
"code_actions_on_format": {
"source.organizeImports": true
@@ -819,9 +812,6 @@
"plugins": ["prettier-plugin-sql"]
}
},
"Starlark": {
"language_servers": ["starpls", "!buck2-lsp", "..."]
},
"Svelte": {
"prettier": {
"allowed": true,
@@ -862,21 +852,6 @@
}
}
},
// Different settings for specific language models.
"language_models": {
"anthropic": {
"api_url": "https://api.anthropic.com"
},
"openai": {
"api_url": "https://api.openai.com/v1"
},
"google": {
"api_url": "https://generativelanguage.googleapis.com"
},
"ollama": {
"api_url": "http://localhost:11434"
}
},
// Zed's Prettier integration settings.
// Allows to enable/disable formatting with Prettier
// and configure default Prettier, used when no project-level Prettier installation is found.
@@ -909,15 +884,6 @@
// }
// }
},
// Jupyter settings
"jupyter": {
"enabled": true
// Specify the language name as the key and the kernel name as the value.
// "kernel_selections": {
// "python": "conda-base"
// "typescript": "deno"
// }
},
// Vim settings
"vim": {
"use_system_clipboard": "always",
@@ -970,29 +936,5 @@
// Examples:
// - "proxy": "socks5://localhost:10808"
// - "proxy": "http://127.0.0.1:10809"
"proxy": null,
// Set to configure aliases for the command palette.
// When typing a query which is a key of this object, the value will be used instead.
//
// Examples:
// {
// "W": "workspace::Save"
// }
"command_aliases": {},
// ssh_connections is an array of ssh connections.
// By default this setting is null, which disables the direct ssh connection support.
// You can configure these from `project: Open Remote` in the command palette.
// Zed's ssh support will pull configuration from your ~/.ssh too.
// Examples:
// [
// {
// "host": "example-box",
// "projects": [
// {
// "paths": ["/home/user/code/zed"]
// }
// ]
// }
// ]
"ssh_connections": null
"proxy": null
}

View File

@@ -1,5 +1,5 @@
// Folder-specific settings
//
// For a full list of overridable settings, and general information on folder-specific settings,
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
// see the documentation: https://zed.dev/docs/configuring-zed#folder-specific-settings
{}

View File

@@ -17,27 +17,6 @@
// What to do with the terminal pane and tab, after the command was started:
// * `always` — always show the terminal pane, add and focus the corresponding task's tab in it (default)
// * `never` — avoid changing current terminal pane focus, but still add/reuse the task's tab there
"reveal": "always",
// What to do with the terminal pane and tab, after the command had finished:
// * `never` — Do nothing when the command finishes (default)
// * `always` — always hide the terminal tab, hide the pane also if it was the last tab in it
// * `on_success` — hide the terminal tab on task success only, otherwise behaves similar to `always`
"hide": "never",
// Which shell to use when running a task inside the terminal.
// May take 3 values:
// 1. (default) Use the system's default terminal configuration in /etc/passwd
// "shell": "system"
// 2. A program:
// "shell": {
// "program": "sh"
// }
// 3. A program with arguments:
// "shell": {
// "with_arguments": {
// "program": "/bin/bash",
// "arguments": ["--login"]
// }
// }
"shell": "system"
"reveal": "always"
}
]

View File

@@ -18,7 +18,7 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
http.workspace = true
isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true

View File

@@ -1,9 +1,9 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::{convert::TryFrom, time::Duration};
use strum::EnumIter;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
@@ -20,8 +20,6 @@ pub enum Model {
Claude3Sonnet,
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
#[serde(rename = "custom")]
Custom { name: String, max_tokens: usize },
}
impl Model {
@@ -35,88 +33,139 @@ impl Model {
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
Err(anyhow!("invalid model id"))
Err(anyhow!("Invalid model id: {}", id))
}
}
pub fn id(&self) -> &str {
pub fn id(&self) -> &'static str {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
Model::Claude3Opus => "claude-3-opus-20240229",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-opus-20240307",
Self::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
pub fn display_name(&self) -> &'static str {
match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom { name, .. } => name,
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200_000,
Self::Custom { max_tokens, .. } => *max_tokens,
200_000
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
impl TryFrom<String> for Role {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self> {
match value.as_str() {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
_ => Err(anyhow!("invalid role '{value}'")),
}
}
}
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<Response> {
let uri = format!("{api_url}/v1/messages");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
.header("Anthropic-Beta", "tools-2024-04-04")
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_message: Response = serde_json::from_slice(&body)?;
Ok(response_message)
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
Err(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str
))
impl From<Role> for String {
fn from(val: Role) -> Self {
match val {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
}
}
}
#[derive(Debug, Serialize)]
pub struct Request {
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub system: String,
pub max_tokens: u32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
message: ResponseMessage,
},
ContentBlockStart {
index: u32,
content_block: ContentBlock,
},
Ping {},
ContentBlockDelta {
index: u32,
delta: TextDelta,
},
ContentBlockStop {
index: u32,
},
MessageDelta {
delta: ResponseMessage,
usage: Usage,
},
MessageStop {},
}
#[derive(Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
pub id: Option<String>,
pub role: Option<String>,
pub content: Option<Vec<String>>,
pub model: Option<String>,
pub stop_reason: Option<String>,
pub stop_sequence: Option<String>,
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
pub struct Usage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<Event>>> {
let request = StreamingRequest {
base: request,
stream: true,
};
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let uri = format!("{api_url}/v1/messages");
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
@@ -128,9 +177,7 @@ pub async fn stream_completion(
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
}
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -155,7 +202,7 @@ pub async fn stream_completion(
let body_str = std::str::from_utf8(&body)?;
match serde_json::from_str::<Event>(body_str) {
match serde_json::from_str::<ResponseEvent>(body_str) {
Ok(_) => Err(anyhow!(
"Unexpected success response while expecting an error: {}",
body_str,
@@ -169,183 +216,42 @@ pub async fn stream_completion(
}
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<Event>>,
) -> impl Stream<Item = Result<String>> {
response.filter_map(|response| async move {
match response {
Ok(response) => match response {
Event::ContentBlockStart { content_block, .. } => match content_block {
Content::Text { text } => Some(Ok(text)),
_ => None,
},
Event::ContentBlockDelta { delta, .. } => match delta {
ContentDelta::TextDelta { text } => Some(Ok(text)),
_ => None,
},
_ => None,
},
Err(error) => Some(Err(error)),
}
})
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use http::IsahcHttpClient;
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: Vec<Content>,
}
// #[tokio::test]
// async fn stream_completion_success() {
// let http_client = IsahcHttpClient::new().unwrap();
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
// let request = Request {
// model: Model::Claude3Opus,
// messages: vec![RequestMessage {
// role: Role::User,
// content: "Ping".to_string(),
// }],
// stream: true,
// system: "Respond to ping with pong".to_string(),
// max_tokens: 4096,
// };
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Content {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { source: ImageSource },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
},
}
// let stream = stream_completion(
// &http_client,
// "https://api.anthropic.com",
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
// request,
// )
// .await
// .unwrap();
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
#[serde(rename = "type")]
pub source_type: String,
pub media_type: String,
pub data: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
Auto,
Any,
Tool { name: String },
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub max_tokens: u32,
pub messages: Vec<Message>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<Metadata>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
#[serde(flatten)]
pub base: Request,
pub stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
pub user_id: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input_tokens: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_tokens: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
pub id: String,
#[serde(rename = "type")]
pub response_type: String,
pub role: Role,
pub content: Vec<Content>,
pub model: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_sequence: Option<String>,
pub usage: Usage,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
#[serde(rename = "message_start")]
MessageStart { message: Response },
#[serde(rename = "content_block_start")]
ContentBlockStart {
index: usize,
content_block: Content,
},
#[serde(rename = "content_block_delta")]
ContentBlockDelta { index: usize, delta: ContentDelta },
#[serde(rename = "content_block_stop")]
ContentBlockStop { index: usize },
#[serde(rename = "message_delta")]
MessageDelta { delta: MessageDelta, usage: Usage },
#[serde(rename = "message_stop")]
MessageStop,
#[serde(rename = "ping")]
Ping,
#[serde(rename = "error")]
Error { error: ApiError },
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
pub stop_reason: Option<String>,
pub stop_sequence: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
#[serde(rename = "type")]
pub error_type: String,
pub message: String,
}
// stream
// .for_each(|event| async {
// match event {
// Ok(event) => println!("{:?}", event),
// Err(e) => eprintln!("Error: {:?}", e),
// }
// })
// .await;
// }
// }
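
The `extract_text_from_events` helper visible in the hunks above filters a stream of completion events down to plain text chunks. A self-contained sketch of that pattern, using trimmed stand-ins for the crate's `Event` and `ContentDelta` types:

use anyhow::Result;
use futures::{stream, Stream, StreamExt};

enum ContentDelta {
    TextDelta { text: String },
}

enum Event {
    ContentBlockDelta { delta: ContentDelta },
    MessageStop,
}

// Mirrors the filter_map shape shown in the diff: keep text deltas, forward errors.
fn extract_text(events: impl Stream<Item = Result<Event>>) -> impl Stream<Item = Result<String>> {
    events.filter_map(|event| async move {
        match event {
            Ok(Event::ContentBlockDelta { delta: ContentDelta::TextDelta { text } }) => Some(Ok(text)),
            Ok(_) => None,
            Err(error) => Some(Err(error)),
        }
    })
}

fn main() {
    let events = stream::iter(vec![
        Ok(Event::ContentBlockDelta { delta: ContentDelta::TextDelta { text: "pong".into() } }),
        Ok(Event::MessageStop),
    ]);
    let texts: Vec<String> = futures::executor::block_on(
        extract_text(events).map(Result::unwrap).collect::<Vec<_>>(),
    );
    assert_eq!(texts, vec!["pong".to_string()]);
}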

View File

@@ -26,13 +26,13 @@ anyhow.workspace = true
assets.workspace = true
assistant_slash_command.workspace = true
async-watch.workspace = true
breadcrumbs.workspace = true
cargo_toml.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
completion.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@@ -41,11 +41,10 @@ fuzzy.workspace = true
gpui.workspace = true
heed.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
http.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
@@ -65,19 +64,21 @@ serde_json.workspace = true
settings.workspace = true
similar.workspace = true
smol.workspace = true
strum.workspace = true
telemetry_events.workspace = true
terminal.workspace = true
terminal_view.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true
toml.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
picker.workspace = true
roxmltree = "0.20.0"
[dev-dependencies]
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true

View File

@@ -38,7 +38,7 @@ Considering these aspects will ensure our conversation view design is optimized
@nate> 2 feels like it isn't important at the moment, we can explore that later. Let's start with 4, which I think will lead us to discussion 3 and 5.
#zed share your thoughts on the points we need to consider to design a layout and visualization for a conversation view between you (#zed) and multiple people, or between multiple people and multiple bots (you and other bots).
#zed share your thoughts on the points we need to consider to design a layout and visualization for a conversation view between you (#zed) and multuple peoople, or between multiple people and multiple bots (you and other bots).
@nathan> Agreed. I'm interested in threading I think more than anything. Or 4 yeah. I think we need to scope the threading conversation. Also, asking #zed to propose the solution... not sure it will be that effective but it's worth a try...

View File

@@ -1,5 +1,6 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod completion_provider;
mod context;
pub mod context_store;
mod inline_assistant;
@@ -11,30 +12,30 @@ mod streaming_diff;
mod terminal_inline_assistant;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
use completion::LanguageModelCompletionProvider;
pub use completion_provider::*;
pub use context::*;
pub use context_store::*;
use fs::Fs;
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::{
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
};
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use settings::{Settings, SettingsStore};
use slash_command::{
active_command, default_command, diagnostics_command, docs_command, fetch_command,
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
tabs_command, term_command,
};
use std::sync::Arc;
use std::{
fmt::{self, Display},
sync::Arc,
};
pub(crate) use streaming_diff::*;
actions!(
@@ -72,6 +73,166 @@ impl MessageId {
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn from_proto(role: i32) -> Role {
match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
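// Editorial sketch (not part of the original change): exercises the Role
// helpers above. `cycle` presumably backs the `CycleMessageRole` action in
// the panel, stepping User -> Assistant -> System -> User.
#[cfg(test)]
#[test]
fn role_helpers_sketch() {
    assert_eq!(Role::User.cycle(), Role::Assistant);
    assert_eq!(Role::System.cycle(), Role::User);
    assert_eq!(Role::Assistant.to_string(), "assistant");
    // Unknown proto roles fall back to User.
    assert_eq!(Role::from_proto(42), Role::User);
}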
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
Ollama(OllamaModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::Cloud(CloudModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
LanguageModel::Ollama(model) => model.display_name().into(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
LanguageModel::Ollama(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
LanguageModel::Ollama(model) => model.id(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
/// Before sending the request to the server, we can perform fixups on it that are appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
preprocess_anthropic_request(self);
}
_ => {}
},
}
}
}
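// Editorial sketch (not part of the original change): how a request is
// assembled, preprocessed, and converted for the wire. The message text and
// model choice are illustrative only.
#[cfg(test)]
#[test]
fn language_model_request_sketch() {
    let mut request = LanguageModelRequest {
        model: LanguageModel::Cloud(CloudModel::Claude3_5Sonnet),
        messages: vec![LanguageModelRequestMessage {
            role: Role::User,
            content: "Summarize this file.".into(),
        }],
        stop: Vec::new(),
        temperature: 1.0,
    };
    // For Claude models this dispatches to `preprocess_anthropic_request`.
    request.preprocess();
    let proto_request = request.to_proto();
    assert_eq!(proto_request.model, "claude-3-5-sonnet");
}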
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
@@ -165,16 +326,6 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
cx.set_global(Assistant::default());
AssistantSettings::register(cx);
// TODO: remove this when 0.148.0 is released.
if AssistantSettings::get_global(cx).using_outdated_settings_version {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let fs = fs.clone();
|content, cx| {
content.update_file(fs, cx);
}
});
}
cx.spawn(|mut cx| {
let client = client.clone();
async move {
@@ -192,7 +343,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client);
prompt_library::init(cx);
init_completion_provider(cx);
completion_provider::init(client.clone(), cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
@@ -217,38 +368,6 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}
fn init_completion_provider(cx: &mut AppContext) {
completion::init(cx);
update_active_language_model_from_settings(cx);
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
.detach();
cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
update_active_language_model_from_settings(cx)
})
.detach();
}
fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());
let Some(provider) = LanguageModelRegistry::global(cx)
.read(cx)
.provider(&provider_name)
else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
completion_provider.set_active_model(model, cx);
});
}
}
fn register_slash_commands(cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true);

View File

@@ -8,17 +8,18 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandRegistry,
},
terminal_inline_assistant::TerminalInlineAssistant,
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations,
EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor,
MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection,
RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep,
EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant,
InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split,
ToggleFocus, ToggleModelSelector,
};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use breadcrumbs::Breadcrumbs;
use client::proto;
use collections::{BTreeSet, HashMap, HashSet};
use completion::LanguageModelCompletionProvider;
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{
@@ -42,14 +43,12 @@ use language::{
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
ToOffset,
};
use language_model::Role;
use multi_buffer::MultiBufferRow;
use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate};
use search::{buffer_search::DivRegistrar, BufferSearchBar};
use settings::Settings;
use std::{
borrow::Cow,
cmp::{self, Ordering},
fmt::Write,
ops::Range,
@@ -58,6 +57,7 @@ use std::{
time::Duration,
};
use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
use theme::ThemeSettings;
use ui::{
prelude::*,
utils::{format_distance_from_now, DateTimeType},
@@ -67,7 +67,7 @@ use ui::{
use util::ResultExt;
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
item::{self, FollowableItem, Item, ItemHandle},
item::{self, BreadcrumbText, FollowableItem, Item, ItemHandle},
notifications::NotifyTaskExt,
pane::{self, SaveIntent},
searchable::{SearchEvent, SearchableItem},
@@ -112,7 +112,6 @@ pub struct AssistantPanel {
subscriptions: Vec<Subscription>,
authentication_prompt: Option<AnyView>,
model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
model_summary_editor: View<Editor>,
}
#[derive(Clone)]
@@ -300,14 +299,6 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>,
) -> Self {
let model_selector_menu_handle = PopoverMenuHandle::default();
let model_summary_editor = cx.new_view(|cx| Editor::single_line(cx));
let context_editor_toolbar = cx.new_view(|_| {
ContextEditorToolbarItem::new(
workspace,
model_selector_menu_handle.clone(),
model_summary_editor.clone(),
)
});
let pane = cx.new_view(|cx| {
let mut pane = Pane::new(
workspace.weak_handle(),
@@ -353,7 +344,13 @@ impl AssistantPanel {
.into_any_element()
});
pane.toolbar().update(cx, |toolbar, cx| {
toolbar.add_item(context_editor_toolbar.clone(), cx);
toolbar.add_item(cx.new_view(|_| Breadcrumbs::new()), cx);
toolbar.add_item(
cx.new_view(|_| {
ContextEditorToolbarItem::new(workspace, model_selector_menu_handle.clone())
}),
cx,
);
toolbar.add_item(cx.new_view(BufferSearchBar::new), cx)
});
pane
@@ -362,14 +359,13 @@ impl AssistantPanel {
let subscriptions = vec![
cx.observe(&pane, |_, _, cx| cx.notify()),
cx.subscribe(&pane, Self::handle_pane_event),
cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
cx.observe(
&LanguageModelCompletionProvider::global(cx),
|this, _, cx| {
this.completion_provider_changed(cx);
},
),
cx.observe_global::<CompletionProvider>({
let mut prev_settings_version = CompletionProvider::global(cx).settings_version();
move |this, cx| {
this.completion_provider_changed(prev_settings_version, cx);
prev_settings_version = CompletionProvider::global(cx).settings_version();
}
}),
];
Self {
@@ -384,7 +380,6 @@ impl AssistantPanel {
subscriptions,
authentication_prompt: None,
model_selector_menu_handle,
model_summary_editor,
}
}
@@ -394,19 +389,10 @@ impl AssistantPanel {
event: &pane::Event,
cx: &mut ViewContext<Self>,
) {
let update_model_summary = match event {
pane::Event::Remove => {
cx.emit(PanelEvent::Close);
false
}
pane::Event::ZoomIn => {
cx.emit(PanelEvent::ZoomIn);
false
}
pane::Event::ZoomOut => {
cx.emit(PanelEvent::ZoomOut);
false
}
match event {
pane::Event::Remove => cx.emit(PanelEvent::Close),
pane::Event::ZoomIn => cx.emit(PanelEvent::ZoomIn),
pane::Event::ZoomOut => cx.emit(PanelEvent::ZoomOut),
pane::Event::AddItem { item } => {
self.workspace
@@ -414,7 +400,6 @@ impl AssistantPanel {
item.added_to_pane(workspace, self.pane.clone(), cx)
})
.ok();
true
}
pane::Event::ActivateItem { local } => {
@@ -426,92 +411,45 @@ impl AssistantPanel {
.ok();
}
cx.emit(AssistantPanelEvent::ContextEdited);
true
}
pane::Event::RemoveItem { .. } => {
cx.emit(AssistantPanelEvent::ContextEdited);
true
}
_ => false,
};
_ => {}
}
}
fn completion_provider_changed(
&mut self,
prev_settings_version: usize,
cx: &mut ViewContext<Self>,
) {
if self.is_authenticated(cx) {
self.authentication_prompt = None;
if update_model_summary {
if let Some(editor) = self.active_context_editor(cx) {
self.show_updated_summary(&editor, cx)
}
}
}
fn handle_summary_editor_event(
&mut self,
model_summary_editor: View<Editor>,
event: &EditorEvent,
cx: &mut ViewContext<Self>,
) {
if matches!(event, EditorEvent::Edited { .. }) {
if let Some(context_editor) = self.active_context_editor(cx) {
let new_summary = model_summary_editor.read(cx).text(cx);
context_editor.update(cx, |context_editor, cx| {
context_editor.context.update(cx, |context, cx| {
if context.summary().is_none()
&& (new_summary == DEFAULT_TAB_TITLE || new_summary.trim().is_empty())
{
return;
}
context.custom_summary(new_summary, cx)
});
});
}
}
}
fn handle_toolbar_event(
&mut self,
_: View<ContextEditorToolbarItem>,
_: &ContextEditorToolbarItemEvent,
cx: &mut ViewContext<Self>,
) {
if let Some(context_editor) = self.active_context_editor(cx) {
context_editor.update(cx, |context_editor, cx| {
context_editor.context.update(cx, |context, cx| {
context.summarize(true, cx);
editor.update(cx, |active_context, cx| {
active_context
.context
.update(cx, |context, cx| context.completion_provider_changed(cx))
})
})
}
}
fn completion_provider_changed(&mut self, cx: &mut ViewContext<Self>) {
if let Some(editor) = self.active_context_editor(cx) {
editor.update(cx, |active_context, cx| {
active_context
.context
.update(cx, |context, cx| context.completion_provider_changed(cx))
})
}
if self.active_context_editor(cx).is_none() {
self.new_context(cx);
}
let authentication_prompt = Self::authentication_prompt(cx);
for context_editor in self.context_editors(cx) {
context_editor.update(cx, |editor, cx| {
editor.set_authentication_prompt(authentication_prompt.clone(), cx);
});
}
cx.notify();
}
fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
if !provider.is_authenticated(cx) {
return Some(provider.authentication_prompt(cx));
}
if self.active_context_editor(cx).is_none() {
self.new_context(cx);
}
cx.notify();
} else if self.authentication_prompt.is_none()
|| prev_settings_version != CompletionProvider::global(cx).settings_version()
{
self.authentication_prompt =
Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.authentication_prompt(cx)
}));
cx.notify();
}
None
}
pub fn inline_assist(
@@ -698,43 +636,18 @@ impl AssistantPanel {
.push(cx.subscribe(&context_editor, Self::handle_context_editor_event));
}
self.show_updated_summary(&context_editor, cx);
cx.emit(AssistantPanelEvent::ContextEdited);
cx.notify();
}
fn show_updated_summary(
&self,
context_editor: &View<ContextEditor>,
cx: &mut ViewContext<Self>,
) {
context_editor.update(cx, |context_editor, cx| {
let new_summary = context_editor
.context
.read(cx)
.summary()
.map(|s| s.text.clone())
.unwrap_or_else(|| context_editor.title(cx).to_string());
self.model_summary_editor.update(cx, |summary_editor, cx| {
if summary_editor.text(cx) != new_summary {
summary_editor.set_text(new_summary, cx);
}
});
});
}
fn handle_context_editor_event(
&mut self,
context_editor: View<ContextEditor>,
_: View<ContextEditor>,
event: &EditorEvent,
cx: &mut ViewContext<Self>,
) {
match event {
EditorEvent::TitleChanged => {
self.show_updated_summary(&context_editor, cx);
cx.notify()
}
EditorEvent::TitleChanged { .. } => cx.notify(),
EditorEvent::Edited { .. } => cx.emit(AssistantPanelEvent::ContextEdited),
_ => {}
}
@@ -772,7 +685,7 @@ impl AssistantPanel {
}
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
LanguageModelCompletionProvider::read_global(cx)
CompletionProvider::global(cx)
.reset_credentials(cx)
.detach_and_log_err(cx);
}
@@ -781,13 +694,6 @@ impl AssistantPanel {
self.model_selector_menu_handle.toggle(cx);
}
fn context_editors(&self, cx: &AppContext) -> Vec<View<ContextEditor>> {
self.pane
.read(cx)
.items_of_type::<ContextEditor>()
.collect()
}
fn active_context_editor(&self, cx: &AppContext) -> Option<View<ContextEditor>> {
self.pane
.read(cx)
@@ -909,11 +815,11 @@ impl AssistantPanel {
}
fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
CompletionProvider::global(cx).is_authenticated()
}
fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
cx.update_global::<CompletionProvider, _>(|provider, cx| provider.authenticate(cx))
}
fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
@@ -973,18 +879,14 @@ impl Panel for AssistantPanel {
}
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
settings::update_settings_file::<AssistantSettings>(
self.fs.clone(),
cx,
move |settings, _| {
let dock = match position {
DockPosition::Left => AssistantDockPosition::Left,
DockPosition::Bottom => AssistantDockPosition::Bottom,
DockPosition::Right => AssistantDockPosition::Right,
};
settings.set_dock(dock);
},
);
settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
let dock = match position {
DockPosition::Left => AssistantDockPosition::Left,
DockPosition::Bottom => AssistantDockPosition::Bottom,
DockPosition::Right => AssistantDockPosition::Right,
};
settings.set_dock(dock);
});
}
fn size(&self, cx: &WindowContext) -> Pixels {
@@ -1083,7 +985,6 @@ struct ActiveEditStep {
pub struct ContextEditor {
context: Model<Context>,
authentication_prompt: Option<AnyView>,
fs: Arc<dyn Fs>,
workspace: WeakView<Workspace>,
project: Model<Project>,
@@ -1099,10 +1000,9 @@ pub struct ContextEditor {
assistant_panel: WeakView<AssistantPanel>,
}
const DEFAULT_TAB_TITLE: &str = "New Context";
const MAX_TAB_TITLE_LEN: usize = 16;
impl ContextEditor {
const MAX_TAB_TITLE_LEN: usize = 16;
fn for_context(
context: Model<Context>,
fs: Arc<dyn Fs>,
@@ -1141,7 +1041,6 @@ impl ContextEditor {
let sections = context.read(cx).slash_command_output_sections().to_vec();
let mut this = Self {
context,
authentication_prompt: None,
editor,
lsp_adapter_delegate,
blocks: Default::default(),
@@ -1161,15 +1060,6 @@ impl ContextEditor {
this
}
fn set_authentication_prompt(
&mut self,
authentication_prompt: Option<AnyView>,
cx: &mut ViewContext<Self>,
) {
self.authentication_prompt = authentication_prompt;
cx.notify();
}
fn insert_default_prompt(&mut self, cx: &mut ViewContext<Self>) {
let command_name = DefaultSlashCommand.name();
self.editor.update(cx, |editor, cx| {
@@ -1196,10 +1086,6 @@ impl ContextEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
if self.authentication_prompt.is_some() {
return;
}
if !self.apply_edit_step(cx) {
self.send_to_model(cx);
}
@@ -1207,16 +1093,12 @@ impl ContextEditor {
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
if let Some(step) = self.active_edit_step.as_ref() {
let assist_ids = step.assist_ids.clone();
cx.window_context().defer(|cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
for assist_id in assist_ids {
assistant.start_assist(assist_id, cx);
}
})
});
!step.assist_ids.is_empty()
InlineAssistant::update_global(cx, |assistant, cx| {
for assist_id in &step.assist_ids {
assistant.start_assist(*assist_id, cx);
}
!step.assist_ids.is_empty()
})
} else {
false
}
@@ -1265,7 +1147,11 @@ impl ContextEditor {
.collect::<String>()
));
match &step.operations {
Some(EditStepOperations::Ready(operations)) => {
Some(EditStepOperations::Parsed {
operations,
raw_output,
}) => {
output.push_str(&format!("Raw Output:\n{raw_output}\n"));
output.push_str("Parsed Operations:\n");
for op in operations {
output.push_str(&format!(" {:?}\n", op));
@@ -1429,7 +1315,7 @@ impl ContextEditor {
ContextEvent::SummaryChanged => {
cx.emit(EditorEvent::TitleChanged);
self.context.update(cx, |context, cx| {
context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
context.save(None, self.fs.clone(), cx);
});
}
ContextEvent::StreamedCompletion => {
@@ -1774,7 +1660,7 @@ impl ContextEditor {
&editor,
range,
description,
suggestion.initial_insertion,
suggestion.insert_newline,
Some(workspace.clone()),
assistant_panel.upgrade().as_ref(),
cx,
@@ -1839,7 +1725,7 @@ impl ContextEditor {
inline_assist_suggestions.push((
range,
description,
suggestion.initial_insertion,
suggestion.insert_newline,
));
}
}
@@ -1851,12 +1737,12 @@ impl ContextEditor {
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
cx.update(|cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
for (range, description, initial_insertion) in inline_assist_suggestions {
for (range, description, insert_newline) in inline_assist_suggestions {
assist_ids.push(assistant.suggest_assist(
&editor,
range,
description,
initial_insertion,
insert_newline,
Some(workspace.clone()),
assistant_panel.upgrade().as_ref(),
cx,
@@ -2145,18 +2031,16 @@ impl ContextEditor {
}
fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
self.context.update(cx, |context, cx| {
context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx)
});
self.context
.update(cx, |context, cx| context.save(None, self.fs.clone(), cx));
}
fn title(&self, cx: &AppContext) -> Cow<str> {
fn title(&self, cx: &AppContext) -> String {
self.context
.read(cx)
.summary()
.map(|summary| summary.text.clone())
.map(Cow::Owned)
.unwrap_or_else(|| Cow::Borrowed(DEFAULT_TAB_TITLE))
.unwrap_or_else(|| "New Context".into())
}
fn render_send_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
@@ -2164,7 +2048,7 @@ impl ContextEditor {
let button_text = match self.edit_step_for_cursor(cx) {
Some(edit_step) => match &edit_step.operations {
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
Some(EditStepOperations::Ready(_)) => "Apply Changes",
Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
None => "Send",
},
None => "Send",
@@ -2228,26 +2112,19 @@ impl Render for ContextEditor {
.size_full()
.v_flex()
.child(
if let Some(authentication_prompt) = self.authentication_prompt.as_ref() {
div()
.flex_grow()
.bg(cx.theme().colors().editor_background)
.child(authentication_prompt.clone().into_any())
} else {
div()
.flex_grow()
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone())
.child(
h_flex()
.w_full()
.absolute()
.bottom_0()
.p_4()
.justify_end()
.child(self.render_send_button(cx)),
)
},
div()
.flex_grow()
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone())
.child(
h_flex()
.w_full()
.absolute()
.bottom_0()
.p_4()
.justify_end()
.child(self.render_send_button(cx)),
),
)
}
}
@@ -2262,13 +2139,14 @@ impl Item for ContextEditor {
type Event = editor::EditorEvent;
fn tab_content_text(&self, cx: &WindowContext) -> Option<SharedString> {
Some(util::truncate_and_trailoff(&self.title(cx), MAX_TAB_TITLE_LEN).into())
Some(util::truncate_and_trailoff(&self.title(cx), Self::MAX_TAB_TITLE_LEN).into())
}
fn to_item_events(event: &Self::Event, mut f: impl FnMut(item::ItemEvent)) {
match event {
EditorEvent::Edited { .. } => {
f(item::ItemEvent::Edit);
f(item::ItemEvent::UpdateBreadcrumbs);
}
EditorEvent::TitleChanged => {
f(item::ItemEvent::UpdateTab);
@@ -2278,13 +2156,48 @@ impl Item for ContextEditor {
}
fn tab_tooltip_text(&self, cx: &AppContext) -> Option<SharedString> {
Some(self.title(cx).to_string().into())
Some(self.title(cx).into())
}
fn as_searchable(&self, handle: &View<Self>) -> Option<Box<dyn SearchableItemHandle>> {
Some(Box::new(handle.clone()))
}
fn breadcrumbs(
&self,
theme: &theme::Theme,
cx: &AppContext,
) -> Option<Vec<item::BreadcrumbText>> {
let editor = self.editor.read(cx);
let cursor = editor.selections.newest_anchor().head();
let multibuffer = &editor.buffer().read(cx);
let (_, symbols) = multibuffer.symbols_containing(cursor, Some(&theme.syntax()), cx)?;
let settings = ThemeSettings::get_global(cx);
let mut breadcrumbs = Vec::new();
let title = self.title(cx);
if title.chars().count() > Self::MAX_TAB_TITLE_LEN {
breadcrumbs.push(BreadcrumbText {
text: title,
highlights: None,
font: Some(settings.buffer_font.clone()),
});
}
breadcrumbs.extend(symbols.into_iter().map(|symbol| BreadcrumbText {
text: symbol.text,
highlights: Some(symbol.highlight_ranges),
font: Some(settings.buffer_font.clone()),
}));
Some(breadcrumbs)
}
fn breadcrumb_location(&self) -> ToolbarItemLocation {
ToolbarItemLocation::PrimaryLeft
}
fn set_nav_history(&mut self, nav_history: pane::ItemNavHistory, cx: &mut ViewContext<Self>) {
self.editor.update(cx, |editor, cx| {
Item::set_nav_history(editor, nav_history, cx)
@@ -2492,21 +2405,18 @@ pub struct ContextEditorToolbarItem {
workspace: WeakView<Workspace>,
active_context_editor: Option<WeakView<ContextEditor>>,
model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
model_summary_editor: View<Editor>,
}
impl ContextEditorToolbarItem {
pub fn new(
workspace: &Workspace,
model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
model_summary_editor: View<Editor>,
) -> Self {
Self {
fs: workspace.app_state().fs.clone(),
workspace: workspace.weak_handle(),
active_context_editor: None,
model_selector_menu_handle,
model_summary_editor,
}
}
@@ -2575,7 +2485,7 @@ impl ContextEditorToolbarItem {
}
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = CompletionProvider::global(cx).model();
let context = &self
.active_context_editor
.as_ref()?
@@ -2614,68 +2524,14 @@ impl ContextEditorToolbarItem {
impl Render for ContextEditorToolbarItem {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let left_side = h_flex()
.gap_2()
.flex_1()
.min_w(rems(DEFAULT_TAB_TITLE.len() as f32))
.when(self.active_context_editor.is_some(), |left_side| {
left_side
.child(
IconButton::new("regenerate-context", IconName::ArrowCircle)
.tooltip(|cx| Tooltip::text("Regenerate Summary", cx))
.on_click(cx.listener(move |_, _, cx| {
cx.emit(ContextEditorToolbarItemEvent::RegenerateSummary)
})),
)
.child(self.model_summary_editor.clone())
});
let right_side = h_flex()
.gap_2()
.child(
ModelSelector::new(
self.fs.clone(),
ButtonLike::new("active-model")
.style(ButtonStyle::Subtle)
.child(
h_flex()
.w_full()
.gap_0p5()
.child(
div()
.overflow_x_hidden()
.flex_grow()
.whitespace_nowrap()
.child(
Label::new(
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(
Icon::new(IconName::ChevronDown)
.color(Color::Muted)
.size(IconSize::XSmall),
),
)
.tooltip(move |cx| {
Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
}),
)
.with_handle(self.model_selector_menu_handle.clone()),
)
.children(self.render_remaining_tokens(cx))
.child(self.render_inject_context_menu(cx));
h_flex()
.size_full()
.justify_between()
.child(left_side)
.child(right_side)
.gap_2()
.child(ModelSelector::new(
self.model_selector_menu_handle.clone(),
self.fs.clone(),
))
.children(self.render_remaining_tokens(cx))
.child(self.render_inject_context_menu(cx))
}
}
@@ -2703,11 +2559,6 @@ impl ToolbarItemView for ContextEditorToolbarItem {
impl EventEmitter<ToolbarItemEvent> for ContextEditorToolbarItem {}
enum ContextEditorToolbarItemEvent {
RegenerateSummary,
}
impl EventEmitter<ContextEditorToolbarItemEvent> for ContextEditorToolbarItem {}
pub struct ContextHistory {
picker: View<Picker<SavedContextPickerDelegate>>,
_subscriptions: Vec<Subscription>,
@@ -2901,7 +2752,7 @@ fn make_lsp_adapter_delegate(
project.update(cx, |project, cx| {
// TODO: Find the right worktree.
let worktree = project
.worktrees(cx)
.worktrees()
.next()
.ok_or_else(|| anyhow!("no worktrees when constructing ProjectLspAdapterDelegate"))?;
Ok(ProjectLspAdapterDelegate::new(project, &worktree, cx) as Arc<dyn LspAdapterDelegate>)

View File

@@ -1,14 +1,166 @@
use std::sync::Arc;
use std::fmt;
use anthropic::Model as AnthropicModel;
use fs::Fs;
use gpui::{AppContext, Pixels};
use language_model::{settings::AllLanguageModelSettings, CloudModel, LanguageModel};
use ollama::Model as OllamaModel;
use open_ai::Model as OpenAiModel;
use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsSources};
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
#[default]
Gpt4Omni,
Gpt4OmniMini,
Claude3_5Sonnet,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
}
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Gpt4OmniMini => "gpt-4o-mini",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Gpt4OmniMini => "GPT 4 Omni Mini",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Gpt4OmniMini => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
}
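// Editorial sketch (assumes `serde_json` is available as a dev dependency):
// given the custom Serialize/Deserialize impls above, known ids deserialize
// to their variant, unknown ids fall back to `Custom` instead of erroring,
// and serialization writes the id back out as a bare string.
#[cfg(test)]
#[test]
fn cloud_model_serde_sketch() {
    let model: CloudModel = serde_json::from_str("\"claude-3-opus\"").unwrap();
    assert_eq!(model, CloudModel::Claude3Opus);

    let model: CloudModel = serde_json::from_str("\"my-fine-tune\"").unwrap();
    assert_eq!(model, CloudModel::Custom("my-fine-tune".into()));

    assert_eq!(serde_json::to_string(&model).unwrap(), "\"my-fine-tune\"");
}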
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@@ -19,9 +171,43 @@ pub enum AssistantDockPosition {
Bottom,
}
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
ZedDotDev {
model: CloudModel,
},
OpenAi {
model: OpenAiModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
available_models: Vec<OpenAiModel>,
},
Anthropic {
model: AnthropicModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
Ollama {
model: OllamaModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
}
impl Default for AssistantProvider {
fn default() -> Self {
Self::OpenAi {
model: OpenAiModel::default(),
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContentV1 {
pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")]
ZedDotDev { default_model: Option<CloudModel> },
#[serde(rename = "openai")]
@@ -52,8 +238,7 @@ pub struct AssistantSettings {
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: AssistantDefaultModel,
pub using_outdated_settings_version: bool,
pub provider: AssistantProvider,
}
/// Assistant panel settings
@@ -85,142 +270,34 @@ impl Default for AssistantSettingsContent {
}
impl AssistantSettingsContent {
pub fn is_version_outdated(&self) -> bool {
fn upgrade(&self) -> AssistantSettingsContentV1 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(_) => true,
}
}
pub fn update_file(&mut self, fs: Arc<dyn Fs>, cx: &AppContext) {
if let AssistantSettingsContent::Versioned(settings) = self {
if let VersionedAssistantSettingsContent::V1(settings) = settings {
if let Some(provider) = settings.provider.clone() {
match provider {
AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.anthropic.is_none() {
content.anthropic =
Some(language_model::settings::AnthropicSettingsContent {
api_url,
low_speed_timeout_in_seconds,
..Default::default()
});
}
},
),
AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.ollama.is_none() {
content.ollama =
Some(language_model::settings::OllamaSettingsContent {
api_url,
low_speed_timeout_in_seconds,
});
}
},
),
AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.openai.is_none() {
content.openai =
Some(language_model::settings::OpenAiSettingsContent {
api_url,
low_speed_timeout_in_seconds,
available_models,
});
}
},
),
_ => {}
}
}
}
}
*self = AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
self.upgrade(),
));
}
fn upgrade(&self) -> AssistantSettingsContentV2 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: settings
.provider
.clone()
.and_then(|provider| match provider {
AssistantProviderContentV1::ZedDotDev { default_model } => {
default_model.map(|model| AssistantDefaultModel {
provider: "zed.dev".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::OpenAi { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "openai".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Anthropic { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "anthropic".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Ollama { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "ollama".to_string(),
model: model.id().to_string(),
})
}
}),
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
enabled: None,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: Some(AssistantDefaultModel {
provider: "openai".to_string(),
model: settings
.default_open_ai_model
.clone()
.unwrap_or_default()
.id()
.to_string(),
}),
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
Some(AssistantProviderContent::OpenAi {
default_model: settings.default_open_ai_model.clone(),
api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProviderContent::OpenAi {
default_model: Some(open_ai_model),
api_url: None,
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
}
})
},
},
}
}
@@ -231,9 +308,6 @@ impl AssistantSettingsContent {
VersionedAssistantSettingsContent::V1(settings) => {
settings.dock = Some(dock);
}
VersionedAssistantSettingsContent::V2(settings) => {
settings.dock = Some(dock);
}
},
AssistantSettingsContent::Legacy(settings) => {
settings.dock = Some(dock);
@@ -241,76 +315,74 @@ impl AssistantSettingsContent {
}
}
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
let model = language_model.id().0.to_string();
let provider = language_model.provider_id().0.to_string();
pub fn set_model(&mut self, new_model: LanguageModel) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
log::warn!("attempted to set zed.dev model on outdated settings");
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::Cloud(new_model) = new_model {
*model = Some(new_model);
}
}
"anthropic" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
});
Some(AssistantProviderContent::OpenAi {
default_model: model,
..
}) => {
if let LanguageModel::OpenAi(new_model) = new_model {
*model = Some(new_model);
}
}
"ollama" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model)),
api_url,
low_speed_timeout_in_seconds,
});
Some(AssistantProviderContent::Anthropic {
default_model: model,
..
}) => {
if let LanguageModel::Anthropic(new_model) = new_model {
*model = Some(new_model);
}
}
"openai" => {
let (api_url, low_speed_timeout_in_seconds, available_models) =
match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
}) => (
api_url.clone(),
*low_speed_timeout_in_seconds,
available_models.clone(),
),
_ => (None, None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: open_ai::Model::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
available_models,
});
Some(AssistantProviderContent::Ollama {
default_model: model,
..
}) => {
if let LanguageModel::Ollama(new_model) = new_model {
*model = Some(new_model);
}
}
_ => {}
provider => match new_model {
LanguageModel::Cloud(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
}
LanguageModel::OpenAi(model) => {
*provider = Some(AssistantProviderContent::OpenAi {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
})
}
LanguageModel::Anthropic(model) => {
*provider = Some(AssistantProviderContent::Anthropic {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
LanguageModel::Ollama(model) => {
*provider = Some(AssistantProviderContent::Ollama {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
},
},
VersionedAssistantSettingsContent::V2(settings) => {
settings.default_model = Some(AssistantDefaultModel { provider, model });
}
},
AssistantSettingsContent::Legacy(settings) => {
if let Ok(model) = open_ai::Model::from_id(&language_model.id().0) {
if let LanguageModel::OpenAi(model) = new_model {
settings.default_open_ai_model = Some(model);
}
}
@@ -323,78 +395,21 @@ impl AssistantSettingsContent {
pub enum VersionedAssistantSettingsContent {
#[serde(rename = "1")]
V1(AssistantSettingsContentV1),
#[serde(rename = "2")]
V2(AssistantSettingsContentV2),
}
impl Default for VersionedAssistantSettingsContent {
fn default() -> Self {
Self::V2(AssistantSettingsContentV2 {
Self::V1(AssistantSettingsContentV1 {
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
default_model: None,
provider: None,
})
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV2 {
/// Whether the Assistant is enabled.
///
/// Default: true
enabled: Option<bool>,
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
button: Option<bool>,
/// Where to dock the assistant.
///
/// Default: right
dock: Option<AssistantDockPosition>,
/// Default width in pixels when the assistant is docked to the left or right.
///
/// Default: 640
default_width: Option<f32>,
/// Default height in pixels when the assistant is docked to the bottom.
///
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new contexts.
default_model: Option<AssistantDefaultModel>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AssistantDefaultModel {
#[schemars(schema_with = "providers_schema")]
pub provider: String,
pub model: String,
}
fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"anthropic".into(),
"ollama".into(),
"openai".into(),
"zed.dev".into(),
]),
..Default::default()
}
.into()
}
impl Default for AssistantDefaultModel {
fn default() -> Self {
Self {
provider: "openai".to_string(),
model: "gpt-4".to_string(),
}
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
/// Whether the Assistant is enabled.
@@ -421,7 +436,7 @@ pub struct AssistantSettingsContentV1 {
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
provider: Option<AssistantProviderContentV1>,
provider: Option<AssistantProviderContent>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@@ -464,10 +479,6 @@ impl Settings for AssistantSettings {
let mut settings = AssistantSettings::default();
for value in sources.defaults_and_customizations() {
if value.is_version_outdated() {
settings.using_outdated_settings_version = true;
}
let value = value.upgrade();
merge(&mut settings.enabled, value.enabled);
merge(&mut settings.button, value.button);
@@ -480,10 +491,123 @@ impl Settings for AssistantSettings {
&mut settings.default_height,
value.default_height.map(Into::into),
);
merge(
&mut settings.default_model,
value.default_model.map(Into::into),
);
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
AssistantProvider::ZedDotDev { model },
AssistantProviderContent::ZedDotDev {
default_model: model_override,
},
) => {
merge(model, model_override);
}
(
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
},
AssistantProviderContent::OpenAi {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
available_models: available_models_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
merge(available_models, available_models_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Ollama {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(provider, provider_override) => {
*provider = match provider_override {
AssistantProviderContent::ZedDotDev {
default_model: model,
} => AssistantProvider::ZedDotDev {
model: model.unwrap_or_default(),
},
AssistantProviderContent::OpenAi {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => AssistantProvider::OpenAi {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
low_speed_timeout_in_seconds,
available_models: available_models.unwrap_or_default(),
},
AssistantProviderContent::Anthropic {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Anthropic {
model: model.unwrap_or_default(),
api_url: api_url
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Ollama {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Ollama {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
low_speed_timeout_in_seconds,
},
};
}
}
}
}
Ok(settings)
@@ -496,103 +620,96 @@ fn merge<T>(target: &mut T, value: Option<T>) {
}
}
// #[cfg(test)]
// mod tests {
// use gpui::{AppContext, UpdateGlobal};
// use settings::SettingsStore;
#[cfg(test)]
mod tests {
use gpui::{AppContext, UpdateGlobal};
use settings::SettingsStore;
// use super::*;
use super::*;
// #[gpui::test]
// fn test_deserialize_assistant_settings(cx: &mut AppContext) {
// let store = settings::SettingsStore::test(cx);
// cx.set_global(store);
#[gpui::test]
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
let store = settings::SettingsStore::test(cx);
cx.set_global(store);
// // Settings default to gpt-4-turbo.
// AssistantSettings::register(cx);
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::FourOmni,
// api_url: open_ai::OPEN_AI_API_URL.into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// Settings default to the OpenAI provider with gpt-4o.
AssistantSettings::register(cx);
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::FourOmni,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
// // Ensure backward-compatibility.
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "openai_api_url": "test-url",
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::FourOmni,
// api_url: "test-url".into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "default_open_ai_model": "gpt-4-0613"
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::Four,
// api_url: open_ai::OPEN_AI_API_URL.into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// Ensure backward-compatibility.
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"openai_api_url": "test-url",
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::FourOmni,
api_url: "test-url".into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"default_open_ai_model": "gpt-4-0613"
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::Four,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
// // The new version supports setting a custom model when using zed.dev.
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "version": "1",
// "provider": {
// "name": "zed.dev",
// "default_model": {
// "custom": {
// "name": "custom-provider"
// }
// }
// }
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::ZedDotDev {
// model: CloudModel::Custom {
// name: "custom-provider".into(),
// max_tokens: None
// }
// }
// );
// }
// }
// The new version supports setting a custom model when using zed.dev.
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": "custom"
}
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: CloudModel::Custom("custom".into())
}
);
}
}

View File

@@ -0,0 +1,396 @@
mod anthropic;
mod cloud;
#[cfg(any(test, feature = "test-support"))]
mod fake;
mod ollama;
mod open_ai;
pub use anthropic::*;
pub use cloud::*;
#[cfg(any(test, feature = "test-support"))]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;
use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
/// Chooses which model to use for the OpenAI provider.
/// If the requested model is not available, use the first available model, or fall back to the original model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
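// Editorial sketch of the fallback behavior (the `Model` variant names are
// assumptions about the `open_ai::Model` enum, which the tests below also
// reference):
#[cfg(test)]
#[test]
fn choose_openai_model_sketch() {
    use ::open_ai::Model;

    // The requested model is not configured, so the first available one wins.
    let available = vec![Model::Four];
    assert_eq!(choose_openai_model(&Model::FourOmni, &available), Model::Four);

    // With no available models configured, the requested model is kept.
    assert_eq!(choose_openai_model(&Model::FourOmni, &[]), Model::FourOmni);
}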
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let provider = create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.update_settings(settings_version, cx);
})
})
.detach();
}
pub struct CompletionResponse {
inner: BoxStream<'static, Result<String>>,
_lock: SemaphoreGuardArc,
}
impl futures::Stream for CompletionResponse {
type Item = Result<String>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}
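// Editorial sketch of the rate-limiting mechanism: the guard returned by
// `acquire_arc` (stored in `CompletionResponse::_lock`) holds a semaphore
// permit until the response is dropped, which is what caps in-flight
// completions at MAX_CONCURRENT_COMPLETION_REQUESTS.
#[cfg(test)]
#[test]
fn semaphore_guard_sketch() {
    let semaphore = Arc::new(Semaphore::new(2));
    let first = semaphore.try_acquire_arc().unwrap();
    let _second = semaphore.try_acquire_arc().unwrap();
    assert!(semaphore.try_acquire_arc().is_none()); // at capacity
    drop(first); // dropping a guard returns its permit
    assert!(semaphore.try_acquire_arc().is_some());
}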
pub trait LanguageModelCompletionProvider: Send + Sync {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
fn settings_version(&self) -> usize;
fn is_authenticated(&self) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
fn model(&self) -> LanguageModel;
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>>;
fn stream_completion(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn as_any_mut(&mut self) -> &mut dyn Any;
}
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
pub struct CompletionProvider {
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
client: Option<Arc<Client>>,
request_limiter: Arc<Semaphore>,
}
impl CompletionProvider {
pub fn new(
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
client: Option<Arc<Client>>,
) -> Self {
Self {
provider,
client,
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
}
}
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
self.provider.read().available_models(cx)
}
pub fn settings_version(&self) -> usize {
self.provider.read().settings_version()
}
pub fn is_authenticated(&self) -> bool {
self.provider.read().is_authenticated()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.provider.read().authenticate(cx)
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
self.provider.read().authentication_prompt(cx)
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.provider.read().reset_credentials(cx)
}
pub fn model(&self) -> LanguageModel {
self.provider.read().model()
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
self.provider.read().count_tokens(request, cx)
}
pub fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<CompletionResponse>> {
let rate_limiter = self.request_limiter.clone();
let provider = self.provider.clone();
cx.foreground_executor().spawn(async move {
let lock = rate_limiter.acquire_arc().await;
let response = provider.read().stream_completion(request);
let response = response.await?;
Ok(CompletionResponse {
inner: response,
_lock: lock,
})
})
}
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
let response = self.stream_completion(request, cx);
cx.foreground_executor().spawn(async move {
let mut chunks = response.await?;
let mut completion = String::new();
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
completion.push_str(&chunk);
}
Ok(completion)
})
}
}
impl gpui::Global for CompletionProvider {}
impl CompletionProvider {
pub fn global(cx: &AppContext) -> &Self {
cx.global::<Self>()
}
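/// Applies `update` to the active provider if its concrete type is `T`;
/// returns `None` when a different provider type is currently configured.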
pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
&mut self,
update: impl FnOnce(&mut T) -> R,
) -> Option<R> {
let mut provider = self.provider.write();
if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
Some(update(provider))
} else {
None
}
}
pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => self
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// The configured provider type changed, so recreate it from settings.
if updated.is_none() {
if let Some(client) = self.client.clone() {
self.provider = create_provider_from_settings(client, version, cx);
} else {
log::warn!("completion provider cannot be created because client is not set");
}
}
}
}
fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use gpui::AppContext;
use parking_lot::RwLock;
use settings::SettingsStore;
use smol::stream::StreamExt;
use crate::{
completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
FakeCompletionProvider, LanguageModelRequest,
};
#[gpui::test]
fn test_rate_limiting(cx: &mut AppContext) {
SettingsStore::test(cx);
let fake_provider = FakeCompletionProvider::setup_test(cx);
let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
// Enqueue some requests
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
let response = provider.stream_completion(
LanguageModelRequest {
temperature: i as f32 / 10.0,
..Default::default()
},
cx,
);
cx.background_executor()
.spawn(async move {
let mut stream = response.await.unwrap();
while let Some(message) = stream.next().await {
message.unwrap();
}
})
.detach();
}
cx.background_executor().run_until_parked();
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Get the first completion request that is in flight and mark it as completed.
let completion = fake_provider
.pending_completions()
.into_iter()
.next()
.unwrap();
fake_provider.finish_completion(&completion);
// Ensure that the number of in-flight completion requests is reduced.
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
cx.background_executor().run_until_parked();
// Ensure that another completion request was allowed to acquire the lock.
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Mark all in-flight completion requests as finished.
for request in fake_provider.pending_completions() {
fake_provider.finish_completion(&request);
}
assert_eq!(fake_provider.completion_count(), 0);
// Wait until the background tasks acquire the lock again.
cx.background_executor().run_until_parked();
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
// Finish all remaining completion requests.
for request in fake_provider.pending_completions() {
fake_provider.finish_completion(&request);
}
cx.background_executor().run_until_parked();
assert_eq!(fake_provider.completion_count(), 0);
}
}

View File

@@ -0,0 +1,367 @@
use crate::{
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
Role,
};
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
pub struct AnthropicCompletionProvider {
api_key: Option<String>,
api_url: String,
model: AnthropicModel,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
}
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
AnthropicModel::iter()
.map(LanguageModel::Anthropic)
.collect()
}
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = self.api_url.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
});
})
})
}
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = None;
});
})
})
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
fn model(&self) -> LanguageModel {
LanguageModel::Anthropic(self.model.clone())
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
count_open_ai_tokens(request, cx.background_executor())
}
fn stream_completion(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_anthropic_request(request);
let http_client = self.http_client.clone();
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
let low_speed_timeout = self.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(response) => match response {
anthropic::ResponseEvent::ContentBlockStart {
content_block, ..
} => match content_block {
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
},
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
match delta {
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
}
}
_ => None,
},
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl AnthropicCompletionProvider {
pub fn new(
model: AnthropicModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) -> Self {
Self {
api_key: None,
api_url,
model,
http_client,
low_speed_timeout,
settings_version,
}
}
pub fn update(
&mut self,
model: AnthropicModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) {
self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let model = match request.model {
LanguageModel::Anthropic(model) => model,
_ => self.model.clone(),
};
let mut system_message = String::new();
if request
.messages
.first()
.map_or(false, |message| message.role == Role::System)
{
system_message = request.messages.remove(0).content;
}
Request {
model,
messages: request
.messages
.iter()
.map(|msg| RequestMessage {
role: match msg.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("filtered out by preprocess_request"),
},
content: msg.content.clone(),
})
.collect(),
stream: true,
system: system_message,
max_tokens: 4096,
}
}
}
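/// Anthropic's Messages API expects alternating user/assistant turns with the
/// system prompt passed separately, so this drops empty messages, merges
/// consecutive same-role messages, and folds all system messages into a
/// single leading system message.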
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
let mut system_message = String::new();
for message in request.messages.drain(..) {
if message.content.is_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == message.role {
last_message.content.push_str("\n\n");
last_message.content.push_str(&message.content);
continue;
}
}
new_messages.push(message);
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
}
}
}
if !system_message.is_empty() {
new_messages.insert(
0,
LanguageModelRequestMessage {
role: Role::System,
content: system_message,
},
);
}
request.messages = new_messages;
}
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,
}
impl AuthenticationPrompt {
fn new(api_url: String, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
editor.set_placeholder_text(
"sk-000000000000000000000000000000000000000000000000",
cx,
);
editor
}),
api_url,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
let api_key = self.api_key.read(cx).text(cx);
if api_key.is_empty() {
return;
}
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
});
})
})
.detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.api_key,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
}
impl Render for AuthenticationPrompt {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
const INSTRUCTIONS: [&str; 4] = [
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
"You can create an API key at: https://console.anthropic.com/settings/keys",
"",
"Paste your Anthropic API key below and hit enter to use the assistant:",
];
v_flex()
.p_4()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.children(
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
)
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
)
.size(LabelSize::Small),
)
.child(
h_flex()
.gap_2()
.child(Label::new("Click on").size(LabelSize::Small))
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
.child(
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
),
)
.into_any()
}
}

View File

@@ -0,0 +1,208 @@
use crate::{
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelCompletionProvider, LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
pub struct CloudCompletionProvider {
client: Arc<Client>,
model: CloudModel,
settings_version: usize,
status: client::Status,
_maintain_client_status: Task<()>,
}
impl CloudCompletionProvider {
pub fn new(
model: CloudModel,
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Self {
let mut status_rx = client.status();
let status = *status_rx.borrow();
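// Mirror the client's connection status into this provider so that
// `is_authenticated` reflects whether we are currently connected.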
let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, Self>(|provider| {
provider.status = status;
});
});
}
});
Self {
client,
model,
settings_version,
status,
_maintain_client_status: maintain_client_status,
}
}
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
self.model = model;
self.settings_version = settings_version;
}
}
impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
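// `CloudModel::iter()` yields a single `Custom` placeholder; replace it
// with the user's configured custom model, or drop it if none is set.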
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
} else {
None
};
CloudModel::iter()
.filter_map(move |model| {
if let CloudModel::Custom(_) = model {
Some(CloudModel::Custom(custom_model.take()?))
} else {
Some(model)
}
})
.map(LanguageModel::Cloud)
.collect()
}
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self) -> bool {
self.status.is_connected()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|_cx| AuthenticationPrompt).into()
}
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn model(&self) -> LanguageModel {
LanguageModel::Cloud(self.model.clone())
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
LanguageModel::Cloud(CloudModel::Gpt4)
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::Cloud(
CloudModel::Claude3_5Sonnet
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku,
) => {
// There's no public tokenizer for Claude 3, so for now use OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::Cloud(CloudModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
});
async move {
let response = request.await?;
Ok(response.token_count as usize)
}
.boxed()
}
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
}
}
fn stream_completion(
&self,
mut request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
request.preprocess();
let request = proto::CompleteWithLanguageModel {
model: request.model.id().to_string(),
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
};
self.client
.request_stream(request)
.map_ok(|stream| {
stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed()
})
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
struct AuthenticationPrompt;
impl Render for AuthenticationPrompt {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
v_flex()
.gap_2()
.child(
Button::new("sign_in", "Sign in")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.style(ButtonStyle::Filled)
.full_width()
.on_click(|_, cx| {
CompletionProvider::global(cx)
.authenticate(cx)
.detach_and_log_err(cx);
}),
)
.child(
div().flex().w_full().items_center().child(
Label::new("Sign in to enable collaboration.")
.color(Color::Muted)
.size(LabelSize::Small),
),
),
)
}
}

View File

@@ -0,0 +1,115 @@
use anyhow::Result;
use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, Task};
use std::sync::Arc;
use ui::WindowContext;
use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
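/// A test-only provider that records each completion request, keyed by its
/// JSON serialization, and lets tests feed chunks into or finish individual
/// pending response streams by hand.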
#[derive(Clone, Default)]
pub struct FakeCompletionProvider {
current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
impl FakeCompletionProvider {
pub fn setup_test(cx: &mut AppContext) -> Self {
use crate::CompletionProvider;
use parking_lot::RwLock;
let this = Self::default();
let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
cx.set_global(provider);
this
}
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
.keys()
.map(|k| serde_json::from_str(k).unwrap())
.collect()
}
pub fn completion_count(&self) -> usize {
self.current_completion_txs.lock().len()
}
pub fn send_completion_chunk(&self, request: &LanguageModelRequest, chunk: String) {
let json = serde_json::to_string(request).unwrap();
self.current_completion_txs
.lock()
.get(&json)
.unwrap()
.unbounded_send(chunk)
.unwrap();
}
pub fn send_last_completion_chunk(&self, chunk: String) {
self.send_completion_chunk(self.pending_completions().last().unwrap(), chunk);
}
pub fn finish_completion(&self, request: &LanguageModelRequest) {
self.current_completion_txs
.lock()
.remove(&serde_json::to_string(request).unwrap())
.unwrap();
}
pub fn finish_last_completion(&self) {
self.finish_completion(self.pending_completions().last().unwrap());
}
}
impl LanguageModelCompletionProvider for FakeCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
vec![LanguageModel::default()]
}
fn settings_version(&self) -> usize {
0
}
fn is_authenticated(&self) -> bool {
true
}
fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
unimplemented!()
}
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn model(&self) -> LanguageModel {
LanguageModel::default()
}
fn count_tokens(
&self,
_request: LanguageModelRequest,
_cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
futures::future::ready(Ok(0)).boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs
.lock()
.insert(serde_json::to_string(&request).unwrap(), tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}

View File

@@ -1,165 +1,50 @@
use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
Role as OllamaRole,
};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
}
pub struct OllamaLanguageModelProvider {
pub struct OllamaCompletionProvider {
api_url: String,
model: OllamaModel,
http_client: Arc<dyn HttpClient>,
state: gpui::Model<State>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
available_models: Vec<OllamaModel>,
}
struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
_subscription: Subscription,
}
impl State {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|this, mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new_model(|cx| State {
http_client,
available_models: Default::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.fetch_models(cx).detach();
cx.notify();
}),
}),
};
this.fetch_models(cx).detach();
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
state.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
}
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
self.state
.read(cx)
.available_models
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
self.available_models
.iter()
.map(|model| {
Arc::new(OllamaLanguageModel {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
}) as Arc<dyn LanguageModel>
})
.map(|m| LanguageModel::Ollama(m.clone()))
.collect()
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
.detach_and_log_err(cx);
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
!self.state.read(cx).available_models.is_empty()
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
@@ -167,9 +52,14 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
let state = self.state.clone();
let fetch_models = Box::new(move |cx: &mut WindowContext| {
state.update(cx, |this, cx| this.fetch_models(cx))
cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider
.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.fetch_models(cx)
})
.unwrap_or_else(|| Task::ready(Ok(())))
})
});
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
@@ -179,68 +69,9 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.fetch_models(cx)
}
}
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
}
impl OllamaLanguageModel {
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
ChatRequest {
model: self.model.name.clone(),
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.content,
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
},
Role::System => ChatMessage::System {
content: msg.content,
},
})
.collect(),
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
stream: true,
options: Some(ChatOptions {
num_ctx: Some(self.model.max_tokens),
stop: Some(request.stop),
temperature: Some(request.temperature),
..Default::default()
}),
}
}
}
impl LanguageModel for OllamaLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
fn model(&self) -> LanguageModel {
LanguageModel::Ollama(self.model.clone())
}
fn count_tokens(
@@ -263,18 +94,12 @@ impl LanguageModel for OllamaLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let api_url = self.api_url.clone();
let low_speed_timeout = self.low_speed_timeout;
async move {
let request =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
@@ -299,15 +124,151 @@ impl LanguageModel for OllamaLanguageModel {
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl OllamaCompletionProvider {
pub fn new(
model: OllamaModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
cx: &AppContext,
) -> Self {
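// Ask the Ollama server to load the configured model ahead of time so the
// first completion doesn't pay the model's load latency.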
cx.spawn({
let api_url = api_url.clone();
let client = http_client.clone();
let model = model.name.clone();
|_| async move {
if model.is_empty() {
return Ok(());
}
preload_model(client.as_ref(), &api_url, &model).await
}
})
.detach_and_log_err(cx);
Self {
api_url,
model,
http_client,
low_speed_timeout,
settings_version,
available_models: Default::default(),
}
}
pub fn update(
&mut self,
model: OllamaModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
cx: &AppContext,
) {
cx.spawn({
let api_url = api_url.clone();
let client = self.http_client.clone();
let model = model.name.clone();
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
})
.detach_and_log_err(cx);
if model.name.is_empty() {
self.select_first_available_model()
} else {
self.model = model;
}
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
pub fn select_first_available_model(&mut self) {
if let Some(model) = self.available_models.first() {
self.model = model.clone();
}
}
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let http_client = self.http_client.clone();
let api_url = self.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<OllamaModel> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| OllamaModel::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.available_models = models;
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
provider.select_first_available_model()
}
});
})
})
}
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
let model = match request.model {
LanguageModel::Ollama(model) => model,
_ => self.model.clone(),
};
ChatRequest {
model: model.name,
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.content,
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
},
Role::System => ChatMessage::System {
content: msg.content,
},
})
.collect(),
keep_alive: model.keep_alive.unwrap_or_default(),
stream: true,
options: Some(ChatOptions {
num_ctx: Some(model.max_tokens),
stop: Some(request.stop),
temperature: Some(request.temperature),
..Default::default()
}),
}
}
}
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OllamaRole::User,
Role::Assistant => OllamaRole::Assistant,
Role::System => OllamaRole::System,
}
}
}

View File

@@ -0,0 +1,378 @@
use crate::assistant_settings::CloudModel;
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
pub struct OpenAiCompletionProvider {
api_key: Option<String>,
api_url: String,
model: OpenAiModel,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
}
impl OpenAiCompletionProvider {
pub fn new(
model: OpenAiModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) -> Self {
Self {
api_key: None,
api_url,
model,
http_client,
low_speed_timeout,
settings_version,
}
}
pub fn update(
&mut self,
model: OpenAiModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) {
self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
LanguageModel::OpenAi(model) => model,
_ => self.model.clone(),
};
Request {
model,
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => RequestMessage::User {
content: msg.content,
},
Role::Assistant => RequestMessage::Assistant {
content: Some(msg.content),
tool_calls: Vec::new(),
},
Role::System => RequestMessage::System {
content: msg.content,
},
})
.collect(),
stream: true,
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
}
}
}
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
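// Models listed in the settings take precedence; otherwise fall back to
// the built-in models, or to the single configured custom model.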
if let AssistantProvider::OpenAi {
available_models, ..
} = &AssistantSettings::get_global(cx).provider
{
if !available_models.is_empty() {
return available_models
.iter()
.cloned()
.map(LanguageModel::OpenAi)
.collect();
}
}
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
vec![self.model.clone()]
} else {
OpenAiModel::iter()
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
.collect()
};
available_models
.into_iter()
.map(LanguageModel::OpenAi)
.collect()
}
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = self.api_url.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, Self>(|provider| {
provider.api_key = Some(api_key);
});
})
})
}
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, Self>(|provider| {
provider.api_key = None;
});
})
})
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
fn model(&self) -> LanguageModel {
LanguageModel::OpenAi(self.model.clone())
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
count_open_ai_tokens(request, cx.background_executor())
}
fn stream_completion(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
let low_speed_timeout = self.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
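/// Counts tokens for a request on the background executor using tiktoken,
/// approximating models tiktoken doesn't know about with the GPT-4 tokenizer.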
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
background_executor
.spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.content),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
match request.model {
LanguageModel::Anthropic(_)
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
}
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
}
})
.boxed()
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OpenAiRole::User,
Role::Assistant => OpenAiRole::Assistant,
Role::System => OpenAiRole::System,
}
}
}
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,
}
impl AuthenticationPrompt {
fn new(api_url: String, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
editor.set_placeholder_text(
"sk-000000000000000000000000000000000000000000000000",
cx,
);
editor
}),
api_url,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
let api_key = self.api_key.read(cx).text(cx);
if api_key.is_empty() {
return;
}
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
});
})
})
.detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.api_key,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
}
impl Render for AuthenticationPrompt {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
const INSTRUCTIONS: [&str; 6] = [
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
" - You can create an API key at: platform.openai.com/api-keys",
" - Make sure your OpenAI account has credits",
" - Having a subscription for another service like GitHub Copilot won't work.",
"",
"Paste your OpenAI API key below and hit enter to use the assistant:",
];
v_flex()
.p_4()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.children(
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
)
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
)
.size(LabelSize::Small),
)
.child(
h_flex()
.gap_2()
.child(Label::new("Click on").size(LabelSize::Small))
.child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
.child(
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
),
)
.into_any()
}
}

View File

@@ -1,15 +1,15 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
LanguageModelCompletionProvider, MessageId, MessageStatus,
prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider,
LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
};
use client::{self, proto, telemetry::Telemetry};
use client::{proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
use fs::{Fs, RemoveOptions};
use fs::Fs;
use futures::{
future::{self, Shared},
FutureExt, StreamExt,
@@ -18,11 +18,9 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
cmp,
@@ -352,7 +350,7 @@ pub struct EditSuggestion {
pub range: Range<language::Anchor>,
/// If None, assume this is a suggestion to delete the range rather than transform it.
pub description: Option<String>,
pub initial_insertion: Option<InitialInsertion>,
pub insert_newline: bool,
}
impl EditStep {
@@ -361,7 +359,7 @@ impl EditStep {
project: &Model<Project>,
cx: &AppContext,
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
return Task::ready(HashMap::default());
};
@@ -471,28 +469,32 @@ impl EditStep {
}
pub enum EditStepOperations {
Pending(Task<Option<()>>),
Ready(Vec<EditOperation>),
Pending(Task<Result<()>>),
Parsed {
operations: Vec<EditOperation>,
raw_output: String,
},
}
impl Debug for EditStepOperations {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
EditStepOperations::Ready(operations) => f
EditStepOperations::Parsed {
operations,
raw_output,
} => f
.debug_struct("EditStepOperations::Parsed")
.field("operations", operations)
.field("raw_output", raw_output)
.finish(),
}
}
}
/// A description of an operation to apply to one location in the codebase.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EditOperation {
/// The path to the file containing the relevant operation
pub path: String,
#[serde(flatten)]
pub kind: EditOperationKind,
}
@@ -519,7 +521,7 @@ impl EditOperation {
parse_status.changed().await?;
}
let initial_insertion = kind.initial_insertion();
let insert_newline = kind.insert_newline();
let suggestion_range = if let Some(symbol) = kind.symbol() {
let outline = buffer
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@@ -528,21 +530,7 @@ impl EditOperation {
.path_candidates
.iter()
.find(|item| item.string == symbol)
.with_context(|| {
format!(
"symbol {:?} not found in path {:?}.\ncandidates: {:?}.\nparse status: {:?}. text:\n{}",
symbol,
path,
outline
.path_candidates
.iter()
.map(|candidate| &candidate.string)
.collect::<Vec<_>>(),
*parse_status.borrow(),
buffer.read_with(&cx, |buffer, _| buffer.text()).unwrap_or_else(|_| "error".to_string())
)
})?;
.context("symbol not found")?;
buffer.update(&mut cx, |buffer, _| {
let outline_item = &outline.items[candidate.id];
let symbol_range = outline_item.range.to_point(buffer);
@@ -597,61 +585,39 @@ impl EditOperation {
EditSuggestion {
range: suggestion_range,
description: kind.description().map(ToString::to_string),
initial_insertion,
insert_newline,
},
))
})
}
}
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
#[serde(tag = "kind")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditOperationKind {
/// Rewrite the specified symbol in its entirely based on the given description.
Update {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Create a new file with the given path based on the given description.
Create {
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol based on the given description before the specified symbol.
InsertSiblingBefore {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol based on the given description after the specified symbol.
InsertSiblingAfter {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol as a child of the specified symbol at the start.
PrependChild {
/// An optional full path to the symbol to be rewritten from the provided list.
/// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol as a child of the specified symbol at the end.
AppendChild {
/// An optional full path to the symbol to be rewritten from the provided list.
/// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Delete the specified symbol.
Delete {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
},
}
@@ -681,13 +647,13 @@ impl EditOperationKind {
}
}
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
pub fn insert_newline(&self) -> bool {
match self {
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
_ => None,
Self::PrependChild { .. }
| Self::AppendChild { .. }
| Self::InsertSiblingAfter { .. }
| Self::InsertSiblingBefore { .. } => true,
_ => false,
}
}
}
@@ -1156,14 +1122,14 @@ impl Context {
.await;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
})
})?;
anyhow::Ok(())
}
.log_err()
});
@@ -1319,24 +1285,7 @@ impl Context {
&self,
edit_step: &EditStep,
cx: &mut ModelContext<Self>,
) -> Task<Option<()>> {
#[derive(Debug, Deserialize, JsonSchema)]
struct EditTool {
/// A sequence of operations to apply to the codebase.
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
operations: Vec<EditOperation>,
}
impl LanguageModelTool for EditTool {
fn name() -> String {
"edit".into()
}
fn description() -> String {
"suggest edits to one or more locations in the codebase".into()
}
}
) -> Task<Result<()>> {
let mut request = self.to_completion_request(cx);
let edit_step_range = edit_step.source_range.clone();
let step_text = self
@@ -1345,41 +1294,158 @@ impl Context {
.text_for_range(edit_step_range.clone())
.collect::<String>();
cx.spawn(|this, mut cx| {
async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
cx.spawn(|this, mut cx| async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
let mut prompt = prompt_store.operations_prompt();
prompt.push_str(&step_text);
let mut prompt = prompt_store.operations_prompt();
prompt.push_str(&step_text);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
let tool_use = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.use_tool::<EditTool>(request, cx)
})?
.await?;
let raw_output = cx
.update(|cx| CompletionProvider::global(cx).complete(request, cx))?
.await?;
this.update(&mut cx, |this, cx| {
let step_index = this
.edit_steps
.binary_search_by(|step| {
step.source_range
.cmp(&edit_step_range, this.buffer.read(cx))
})
.map_err(|_| anyhow!("edit step not found"))?;
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
cx.emit(ContextEvent::EditStepsChanged);
}
anyhow::Ok(())
})?
}
.log_err()
let operations = Self::parse_edit_operations(&raw_output);
this.update(&mut cx, |this, cx| {
let step_index = this
.edit_steps
.binary_search_by(|step| {
step.source_range
.cmp(&edit_step_range, this.buffer.read(cx))
})
.map_err(|_| anyhow!("edit step not found"))?;
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
edit_step.operations = Some(EditStepOperations::Parsed {
operations,
raw_output,
});
cx.emit(ContextEvent::EditStepsChanged);
}
anyhow::Ok(())
})?
})
}
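// Parses operations from model output of the following shape (illustrative;
// the paths, symbols, and descriptions below are examples, not real values):
//
// <operations>
//     <update path="src/foo.rs" symbol="fn main" description="..." />
//     <create path="src/new.rs" description="..." />
//     <insert_sibling_after path="src/foo.rs" symbol="fn main" description="..." />
//     <prepend_child path="src/foo.rs" description="..." />
//     <delete path="src/foo.rs" symbol="fn unused" />
// </operations>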
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
let Some(start_ix) = xml.find("<operations>") else {
return Vec::new();
};
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
return Vec::new();
};
let end_ix = end_ix + start_ix + "</operations>".len();
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
doc.map_or(Vec::new(), |doc| {
doc.root_element()
.children()
.map(|node| {
let tag_name = node.tag_name().name();
let path = node
.attribute("path")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'path'")
})?
.to_string();
let kind = match tag_name {
"update" => EditOperationKind::Update {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"create" => EditOperationKind::Create {
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"prepend_child" => EditOperationKind::PrependChild {
symbol: node.attribute("symbol").map(String::from),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"append_child" => EditOperationKind::AppendChild {
symbol: node.attribute("symbol").map(String::from),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"delete" => EditOperationKind::Delete {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
},
_ => return Err(anyhow!("invalid node {node:?}")),
};
anyhow::Ok(EditOperation { path, kind })
})
.filter_map(|op| op.log_err())
.collect()
})
}
@@ -1544,14 +1610,13 @@ impl Context {
.then_some(message.id)
})?;
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
if !CompletionProvider::global(cx).is_authenticated() {
log::info!("completion provider has no credentials");
return None;
}
let request = self.to_completion_request(cx);
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@@ -1608,7 +1673,7 @@ impl Context {
this.update(&mut cx, |this, cx| {
this.pending_completions
.retain(|completion| completion.id != this.completion_count);
this.summarize(false, cx);
this.summarize(cx);
})?;
anyhow::Ok(())
@@ -1631,14 +1696,11 @@ impl Context {
});
if let Some(telemetry) = this.telemetry.as_ref() {
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let model = CompletionProvider::global(cx).model();
telemetry.report_assistant_event(
Some(this.id.0.clone()),
AssistantKind::Panel,
model_telemetry_id,
model.telemetry_id(),
response_latency,
error_message,
);
@@ -1663,6 +1725,7 @@ impl Context {
.map(|message| message.to_request_message(self.buffer.read(cx)));
LanguageModelRequest {
model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
@@ -1903,9 +1966,9 @@ impl Context {
self.message_anchors.insert(insertion_ix, new_anchor);
}
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
if !CompletionProvider::global(cx).is_authenticated() {
return;
}
@@ -1917,29 +1980,24 @@ impl Context {
content: "Summarize the context into a short title without punctuation.".into(),
}));
let request = LanguageModelRequest {
model: CompletionProvider::global(cx).model(),
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
};
let stream =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
let stream = CompletionProvider::global(cx).stream_completion(request, cx);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
let mut replaced = !replace_old;
while let Some(message) = messages.next().await {
let text = message?;
let mut lines = text.lines();
this.update(&mut cx, |this, cx| {
let version = this.version.clone();
let timestamp = this.next_timestamp();
let summary = this.summary.get_or_insert(ContextSummary::default());
if !replaced && replace_old {
summary.text.clear();
replaced = true;
}
let summary = this.summary.get_or_insert(Default::default());
summary.text.extend(lines.next());
summary.timestamp = timestamp;
let operation = ContextOperation::UpdateSummary {
@@ -2082,52 +2140,35 @@ impl Context {
if let Some(summary) = summary {
let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
let mut discriminant = 1;
let mut new_path;
loop {
new_path = contexts_dir().join(&format!(
"{} - {}.zed.json",
summary.trim(),
discriminant
));
if fs.is_file(&new_path).await {
discriminant += 1;
} else {
break;
let path = if let Some(old_path) = old_path {
old_path
} else {
let mut discriminant = 1;
let mut new_path;
loop {
new_path = contexts_dir().join(&format!(
"{} - {}.zed.json",
summary.trim(),
discriminant
));
if fs.is_file(&new_path).await {
discriminant += 1;
} else {
break;
}
}
}
new_path
};
fs.create_dir(contexts_dir().as_ref()).await?;
fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
fs.atomic_write(path.clone(), serde_json::to_string(&context).unwrap())
.await?;
if let Some(old_path) = old_path {
if new_path != old_path {
fs.remove_file(
&old_path,
RemoveOptions {
recursive: false,
ignore_if_not_exists: true,
},
)
.await?;
}
}
this.update(&mut cx, |this, _| this.path = Some(new_path))?;
this.update(&mut cx, |this, _| this.path = Some(path))?;
}
Ok(())
});
}
pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
let timestamp = self.next_timestamp();
let summary = self.summary.get_or_insert(ContextSummary::default());
summary.timestamp = timestamp;
summary.done = true;
summary.text = custom_summary;
cx.emit(ContextEvent::SummaryChanged);
}
}
#[derive(Debug, Default)]
@@ -2436,7 +2477,7 @@ mod tests {
use crate::{
assistant_panel, prompt_library,
slash_command::{active_command, file_command},
MessageId,
FakeCompletionProvider, MessageId,
};
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
use fs::FakeFs;
@@ -2458,8 +2499,7 @@ mod tests {
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
FakeCompletionProvider::setup_test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2591,8 +2631,7 @@ mod tests {
fn test_message_splitting(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
FakeCompletionProvider::setup_test(cx);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2685,8 +2724,7 @@ mod tests {
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
language_model::LanguageModelRegistry::test(cx);
completion::LanguageModelCompletionProvider::test(cx);
FakeCompletionProvider::setup_test(cx);
cx.set_global(settings_store);
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2771,8 +2809,7 @@ mod tests {
async fn test_slash_commands(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(FakeCompletionProvider::setup_test);
cx.update(Project::init_settings);
cx.update(assistant_panel::init);
let fs = FakeFs::new(cx.background_executor.clone());
@@ -2897,11 +2934,7 @@ mod tests {
cx.update(prompt_library::init);
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let fake_model = fake_provider.test_model();
let fake_provider = cx.update(FakeCompletionProvider::setup_test);
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
@@ -2967,8 +3000,8 @@ mod tests {
});
// Simulate the LLM completion
fake_model.send_last_completion_chunk(llm_response.to_string());
fake_model.finish_last_completion();
fake_provider.send_last_completion_chunk(llm_response.to_string());
fake_provider.finish_last_completion();
// Wait for the completion to be processed
cx.run_until_parked();
@@ -2996,12 +3029,60 @@ mod tests {
}
}
#[test]
fn test_parse_edit_operations() {
let operations = indoc! {r#"
Here are the operations to make all fields of the Canvas struct private:
<operations>
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
</operations>
"#};
let parsed_operations = Context::parse_edit_operations(operations);
assert_eq!(
parsed_operations,
vec![
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub pixels".to_string(),
description: "Remove pub keyword from pixels field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub size".to_string(),
description: "Remove pub keyword from size field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub stride".to_string(),
description: "Remove pub keyword from stride field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub format".to_string(),
description: "Remove pub keyword from format field".to_string(),
},
},
]
);
}
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(FakeCompletionProvider::setup_test);
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@@ -3077,9 +3158,7 @@ mod tests {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.update(FakeCompletionProvider::setup_test);
cx.update(assistant_panel::init);
let slash_commands = cx.update(SlashCommandRegistry::default_global);
slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);

File diff suppressed because it is too large


@@ -1,148 +1,82 @@
use std::sync::Arc;
use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
use fs::Fs;
use gpui::SharedString;
use language_model::LanguageModelRegistry;
use settings::update_settings_file;
use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip};
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
handle: Option<PopoverMenuHandle<ContextMenu>>,
pub struct ModelSelector {
handle: PopoverMenuHandle<ContextMenu>,
fs: Arc<dyn Fs>,
trigger: T,
info_text: Option<SharedString>,
}
impl<T: PopoverTrigger> ModelSelector<T> {
pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
ModelSelector {
handle: None,
fs,
trigger,
info_text: None,
}
}
pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
self.handle = Some(handle);
self
}
pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
self.info_text = Some(text.into());
self
impl ModelSelector {
pub fn new(handle: PopoverMenuHandle<ContextMenu>, fs: Arc<dyn Fs>) -> Self {
ModelSelector { handle, fs }
}
}
impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
fn render(self, _: &mut WindowContext) -> impl IntoElement {
let mut menu = PopoverMenu::new("model-switcher");
if let Some(handle) = self.handle {
menu = menu.with_handle(handle);
}
let info_text = self.info_text.clone();
menu.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
if let Some(info_text) = info_text.clone() {
menu = menu
.custom_row(move |_cx| {
Label::new(info_text.clone())
.color(Color::Muted)
.into_any_element()
})
.separator();
}
for (index, provider) in LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.enumerate()
{
if index > 0 {
menu = menu.separator();
}
menu = menu.header(provider.name().0);
let available_models = provider.provided_models(cx);
if available_models.is_empty() {
impl RenderOnce for ModelSelector {
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
PopoverMenu::new("model-switcher")
.with_handle(self.handle)
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx) {
menu = menu.custom_entry(
{
move |_| {
h_flex()
.w_full()
.gap_1()
.child(Icon::new(IconName::Settings))
.child(Label::new("Configure"))
.into_any()
}
},
{
let provider = provider.id();
move |cx| {
LanguageModelCompletionProvider::global(cx).update(
cx,
|completion_provider, cx| {
completion_provider
.set_active_provider(provider.clone(), cx)
},
);
}
},
);
}
let selected_model = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.id());
let selected_provider = LanguageModelCompletionProvider::read_global(cx)
.active_provider()
.map(|m| m.id());
for available_model in available_models {
menu = menu.custom_entry(
{
let id = available_model.id();
let provider_id = available_model.provider_id();
let model_name = available_model.name().0.clone();
let selected_model = selected_model.clone();
let selected_provider = selected_provider.clone();
move |_| {
h_flex()
.w_full()
.justify_between()
.child(Label::new(model_name.clone()))
.when(
selected_model.as_ref() == Some(&id)
&& selected_provider.as_ref() == Some(&provider_id),
|this| this.child(Icon::new(IconName::Check)),
)
.into_any()
}
let model = model.clone();
move |_| Label::new(model.display_name()).into_any_element()
},
{
let fs = self.fs.clone();
let model = available_model.clone();
let model = model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _| settings.set_model(model),
move |settings| settings.set_model(model),
);
}
},
);
}
}
menu
menu
})
.into()
})
.into()
})
.trigger(self.trigger)
.attach(gpui::AnchorCorner::BottomLeft)
.trigger(
ButtonLike::new("active-model")
.style(ButtonStyle::Subtle)
.child(
h_flex()
.w_full()
.gap_0p5()
.child(
div()
.overflow_x_hidden()
.flex_grow()
.whitespace_nowrap()
.child(
Label::new(
CompletionProvider::global(cx).model().display_name(),
)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(
Icon::new(IconName::ChevronDown)
.color(Color::Muted)
.size(IconSize::XSmall),
),
)
.tooltip(move |cx| {
Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
}),
)
.attach(gpui::AnchorCorner::BottomLeft)
}
}


@@ -1,6 +1,6 @@
use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
LanguageModelCompletionProvider,
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use anyhow::{anyhow, Result};
use assets::Assets;
@@ -19,7 +19,6 @@ use gpui::{
};
use heed::{types::SerdeBincode, Database, RoTxn};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use parking_lot::RwLock;
use picker::{Picker, PickerDelegate};
use rope::Rope;
@@ -636,9 +635,9 @@ impl PromptLibrary {
};
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
let provider = LanguageModelCompletionProvider::read_global(cx);
let provider = CompletionProvider::global(cx);
let initial_prompt = action.prompt.clone();
if provider.is_authenticated(cx) {
if provider.is_authenticated() {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, None, None, initial_prompt, cx)
})
@@ -736,8 +735,11 @@ impl PromptLibrary {
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(
let provider = CompletionProvider::global(cx);
let model = provider.model();
provider.count_tokens(
LanguageModelRequest {
model,
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: body.to_string(),
@@ -749,7 +751,6 @@ impl PromptLibrary {
)
})?
.await?;
this.update(&mut cx, |this, cx| {
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
prompt_editor.token_count = Some(token_count);
@@ -804,7 +805,7 @@ impl PromptLibrary {
let prompt_metadata = self.store.metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
let current_model = CompletionProvider::global(cx).model();
let settings = ThemeSettings::get_global(cx);
Some(
@@ -915,11 +916,7 @@ impl PromptLibrary {
format!(
"Model: {}",
current_model
.as_ref()
.map(|model| model
.name()
.0)
.unwrap_or_default()
.display_name()
),
cx,
)


@@ -6,7 +6,8 @@ pub fn generate_content_prompt(
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
) -> String {
_project_name: Option<String>,
) -> anyhow::Result<String> {
let mut prompt = String::new();
let content_type = match language_name {
@@ -14,16 +15,14 @@ pub fn generate_content_prompt(
writeln!(
prompt,
"Here's a file of text that I'm going to ask you to make an edit to."
)
.unwrap();
)?;
"text"
}
Some(language_name) => {
writeln!(
prompt,
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
)
.unwrap();
)?;
"code"
}
};
@@ -71,7 +70,7 @@ pub fn generate_content_prompt(
write!(prompt, "</document>\n\n").unwrap();
if is_truncated {
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
}
if range.is_empty() {
@@ -108,7 +107,7 @@ pub fn generate_content_prompt(
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
}
prompt
Ok(prompt)
}
pub fn generate_terminal_assistant_prompt(


@@ -33,7 +33,7 @@ impl DiagnosticsSlashCommand {
if query.is_empty() {
let workspace = workspace.read(cx);
let entries = workspace.recent_navigation_history(Some(10), cx);
let path_prefix: Arc<str> = Arc::default();
let path_prefix: Arc<str> = "".into();
Task::ready(
entries
.into_iter()
@@ -284,7 +284,7 @@ fn collect_diagnostics(
PathBuf::try_from(path)
.ok()
.and_then(|path| {
project.read(cx).worktrees(cx).find_map(|worktree| {
project.read(cx).worktrees().find_map(|worktree| {
let worktree = worktree.read(cx);
let worktree_root_path = Path::new(worktree.root_name());
let relative_path = path.strip_prefix(worktree_root_path).ok()?;


@@ -24,7 +24,7 @@ impl DocsSlashCommand {
pub const NAME: &'static str = "docs";
fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = project.read(cx).worktrees().next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {
@@ -219,7 +219,7 @@ impl SlashCommand for DocsSlashCommand {
if index {
// We don't need to hold onto this task, as the `IndexedDocsStore` will hold it
// until it completes.
drop(store.clone().index(package.as_str().into()));
let _ = store.clone().index(package.as_str().into());
}
let items = store.search(package).await;


@@ -10,7 +10,7 @@ use assistant_slash_command::{
use futures::AsyncReadExt;
use gpui::{AppContext, Task, WeakView};
use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use http::{AsyncBody, HttpClient, HttpClientWithUrl};
use language::LspAdapterDelegate;
use ui::prelude::*;
use workspace::Workspace;


@@ -29,7 +29,7 @@ impl FileSlashCommand {
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
let entries = workspace.recent_navigation_history(Some(10), cx);
let path_prefix: Arc<str> = Arc::default();
let path_prefix: Arc<str> = "".into();
Task::ready(
entries
.into_iter()
@@ -188,7 +188,7 @@ fn collect_files(
let project_handle = project.downgrade();
let snapshots = project
.read(cx)
.worktrees(cx)
.worktrees()
.map(|worktree| worktree.read(cx).snapshot())
.collect::<Vec<_>>();
cx.spawn(|mut cx| async move {


@@ -75,7 +75,7 @@ impl ProjectSlashCommand {
}
fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = project.read(cx).worktrees().next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {


@@ -1,6 +1,7 @@
use crate::{
humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
assistant_settings::AssistantSettings, humanize_token_count,
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
@@ -12,12 +13,11 @@ use editor::{
use fs::Fs;
use futures::{channel::mpsc, SinkExt, StreamExt};
use gpui::{
AppContext, Context, EventEmitter, FocusHandle, FocusableView, Global, Model, ModelContext,
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
AppContext, Context, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global,
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
};
use language::Buffer;
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use settings::Settings;
use settings::{update_settings_file, Settings};
use std::{
cmp,
sync::Arc,
@@ -26,7 +26,7 @@ use std::{
use terminal::Terminal;
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::{prelude::*, IconButtonShape, Tooltip};
use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip};
use util::ResultExt;
use workspace::{notifications::NotificationId, Toast, Workspace};
@@ -214,6 +214,8 @@ impl TerminalInlineAssistant {
) -> Result<LanguageModelRequest> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let model = CompletionProvider::global(cx).model();
let shell = std::env::var("SHELL").ok();
let working_directory = assist
.terminal
@@ -265,6 +267,7 @@ impl TerminalInlineAssistant {
});
Ok(LanguageModelRequest {
model,
messages,
stop: Vec::new(),
temperature: 1.0,
@@ -447,19 +450,22 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
impl Render for PromptEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let fs = self.fs.clone();
let buttons = match &self.codegen.read(cx).status {
CodegenStatus::Idle => {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
),
IconButton::new("start", IconName::SparkleAlt)
IconButton::new("start", IconName::Sparkle)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.icon_size(IconSize::XSmall)
.tooltip(|cx| Tooltip::for_action("Generate", &menu::Confirm, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
@@ -470,14 +476,15 @@ impl Render for PromptEditor {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::text("Cancel Assist", cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
),
IconButton::new("stop", IconName::Stop)
.icon_color(Color::Error)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.icon_size(IconSize::XSmall)
.tooltip(|cx| {
Tooltip::with_meta(
"Interrupt Generation",
@@ -495,7 +502,7 @@ impl Render for PromptEditor {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
@@ -503,7 +510,8 @@ impl Render for PromptEditor {
if self.edited_since_done {
IconButton::new("restart", IconName::RotateCw)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.size(ButtonSize::None)
.tooltip(|cx| {
Tooltip::with_meta(
"Restart Generation",
@@ -518,7 +526,7 @@ impl Render for PromptEditor {
} else {
IconButton::new("confirm", IconName::Play)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| {
Tooltip::for_action("Execute generated command", &menu::Confirm, cx)
})
@@ -546,27 +554,59 @@ impl Render for PromptEditor {
.w_12()
.justify_center()
.gap_2()
.child(ModelSelector::new(
self.fs.clone(),
IconButton::new("context", IconName::Settings)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
cx,
)
}),
))
.child(
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx)
{
menu = menu.custom_entry(
{
let model = model.clone();
move |_| {
Label::new(model.display_name())
.into_any_element()
}
},
{
let fs = fs.clone();
let model = model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
);
}
},
);
}
menu
})
.into()
})
.trigger(
IconButton::new("context", IconName::Settings)
.size(ButtonSize::None)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
CompletionProvider::global(cx)
.model()
.display_name()
),
None,
"Click to Change Model",
cx,
)
}),
)
.anchor(gpui::AnchorCorner::BottomRight),
)
.children(
if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
let error_message = SharedString::from(error.to_string());
@@ -708,9 +748,7 @@ impl PromptEditor {
})??;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
@@ -840,7 +878,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = CompletionProvider::global(cx).model();
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@@ -906,11 +944,14 @@ impl PromptEditor {
},
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_weight: FontWeight::NORMAL,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.editor,
@@ -944,7 +985,7 @@ impl TerminalTransaction {
}
pub fn push(&mut self, hunk: String, cx: &mut AppContext) {
// Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
// Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
let input = hunk.replace(CARRIAGE_RETURN, " ");
self.terminal
.update(cx, |terminal, _| terminal.input(input));
@@ -986,12 +1027,8 @@ impl Codegen {
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
let telemetry = self.telemetry.clone();
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let response =
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
let model_telemetry_id = prompt.model.telemetry_id();
let response = CompletionProvider::global(cx).stream_completion(prompt, cx);
self.generation = cx.spawn(|this, mut cx| async move {
let response = response.await;


@@ -1,8 +1,7 @@
[package]
name = "remote"
description = "Client-side subsystem for remote editing"
edition = "2021"
name = "assistant_tooling"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
@@ -10,29 +9,25 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/remote.rs"
doctest = false
[features]
default = []
test-support = ["fs/test-support"]
path = "src/assistant_tooling.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
log.workspace = true
parking_lot.workspace = true
prost.workspace = true
rpc.workspace = true
project.workspace = true
repair_json.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
tempfile.workspace = true
sum_tree.workspace = true
ui.workspace = true
util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
fs = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
unindent.workspace = true


@@ -0,0 +1,85 @@
# Assistant Tooling
Bringing Language Model tool calling to GPUI.
This unlocks:
- **Structured Extraction** of model responses
- **Validation** of model inputs
- **Execution** of chosen tools
## Overview
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When making a chat completion, you can pass a list of tools available to the model. The model will choose `0..n` tools to help it complete the user's task. It's up to _you_ to create the tools that the model can call.
> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
>
> **Assistant**: "Sure, I can help with that. Let me see what I can find."
>
> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
>
> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
>
> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`.
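To make that concrete, here is a condensed sketch of the two traits in action, adapted from the `WeatherTool` in this crate's own tests (the canned weather payload stands in for a real API client, and `execute` resolves immediately for illustration):

```rust
use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolView};
use gpui::{div, prelude::*, Render, Task, View, ViewContext, WindowContext};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

// Deriving `JsonSchema` is what turns this struct into the tool's
// function definition for the model.
#[derive(Deserialize, Serialize, JsonSchema)]
struct WeatherQuery {
    location: String,
    unit: String,
}

#[derive(Clone, Default, Serialize, Deserialize)]
struct WeatherResult {
    location: String,
    temperature: f64,
    unit: String,
}

struct WeatherTool {
    // Stands in for a real API client.
    current_weather: WeatherResult,
}

struct WeatherView {
    input: Option<WeatherQuery>,
    result: Option<WeatherResult>,
    current_weather: WeatherResult,
}

impl Render for WeatherView {
    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        match &self.result {
            Some(result) => div().child(format!("temperature: {}", result.temperature)),
            None => div().child("Calculating weather..."),
        }
    }
}

impl ToolView for WeatherView {
    type Input = WeatherQuery;
    type SerializedState = WeatherResult;

    // The text handed back to the model after the call has run.
    fn generate(&self, _project: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
        serde_json::to_string(&self.result).unwrap_or_default()
    }

    // Called once the streamed arguments first parse successfully.
    fn set_input(&mut self, input: WeatherQuery, cx: &mut ViewContext<Self>) {
        self.input = Some(input);
        cx.notify();
    }

    // A real tool would kick off I/O here and resolve the task later.
    fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
        self.result = Some(self.current_weather.clone());
        Task::ready(Ok(()))
    }

    fn serialize(&self, _cx: &mut ViewContext<Self>) -> WeatherResult {
        self.result.clone().unwrap_or_default()
    }

    fn deserialize(&mut self, output: WeatherResult, _cx: &mut ViewContext<Self>) -> Result<()> {
        self.result = Some(output);
        Ok(())
    }
}

impl LanguageModelTool for WeatherTool {
    type View = WeatherView;

    fn name(&self) -> String {
        "get_current_weather".to_string()
    }

    fn description(&self) -> String {
        "Fetches the current weather for a given location.".to_string()
    }

    fn view(&self, cx: &mut WindowContext) -> View<WeatherView> {
        let current_weather = self.current_weather.clone();
        cx.new_view(|_cx| WeatherView {
            input: None,
            result: None,
            current_weather,
        })
    }
}
```

The division of labor follows the trait docs: `LanguageModelTool` only names the tool and builds a view, while the `ToolView` owns input parsing, execution, rendering, and persistence.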
## Using the Tool Registry
```rust
let mut tool_registry = ToolRegistry::new();
tool_registry
    .register(WeatherTool { api_client })
    .unwrap(); // You can only register one tool per name

let completion = cx.update(|cx| {
    CompletionProvider::global(cx).complete(
        model_name,
        messages,
        Vec::new(),
        1.0,
        // The definitions get passed directly to OpenAI when you want
        // the model to be able to call your tool
        tool_registry.definitions(),
    )
});

let mut stream = completion?.await?;
let mut message = AssistantMessage::new();

while let Some(delta) = stream.next().await {
    // As messages stream in, you'll get both assistant content
    if let Some(content) = &delta.content {
        message
            .body
            .update(cx, |message, cx| message.append(&content, cx));
    }

    // And tool calls!
    for tool_call_delta in delta.tool_calls {
        let index = tool_call_delta.index as usize;
        if index >= message.tool_calls.len() {
            message.tool_calls.resize_with(index + 1, Default::default);
        }

        let tool_call = &mut message.tool_calls[index];

        // Build up an ID
        if let Some(id) = &tool_call_delta.id {
            tool_call.id.push_str(id);
        }

        tool_registry.update_tool_call(
            tool_call,
            tool_call_delta.name.as_deref(),
            tool_call_delta.arguments.as_deref(),
            cx,
        );
    }
}
```
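Note that `update_tool_call` feeds the accumulated `arguments` through the `repair_json` crate on every delta, so the view can attempt a best-effort parse while the JSON is still incomplete. A minimal illustration of that repair step (the exact repaired output is an assumption about the crate's behavior):

```rust
use repair_json::repair;

// Mid-stream, the arguments are still truncated JSON:
let partial = r#"{"location": "San Francisco", "unit": "Cel"#;

// `repair` closes the dangling string and brace so serde can parse it,
// which is what lets `try_set_input` show a provisional input before
// the stream has finished.
if let Ok(repaired) = repair(partial.to_string()) {
    let provisional = serde_json::from_str::<serde_json::Value>(&repaired).ok();
    // `provisional` now reflects the partially streamed query.
}
```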
Once the stream of tokens is complete, you can execute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task<Result<()>>`.
As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` into `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call(tool_call, &mut project_context, cx)`.
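Putting those pieces together, here is a sketch of the post-stream flow. It assumes the `message` and `tool_registry` from the streaming example above, plus a `project` model and an `fs` handle in scope for building the `ProjectContext`:

```rust
// Execute every tool call the model requested. `execute_tool_call`
// returns None for unknown or still-initializing tools.
for tool_call in &mut message.tool_calls {
    if let Some(task) = tool_registry.execute_tool_call(tool_call, cx) {
        task.await?;
    }
}

// Render each call for the user...
let rendered: Vec<_> = message
    .tool_calls
    .iter()
    .filter_map(|call| tool_registry.render_tool_call(call, cx))
    .collect();

// ...and assemble the content that is sent back to the model.
let mut project_context = ProjectContext::new(project.downgrade(), fs.clone());
let content = message
    .tool_calls
    .iter()
    .map(|call| tool_registry.content_for_tool_call(call, &mut project_context, cx))
    .collect::<Vec<_>>()
    .join("\n");
```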


@@ -0,0 +1,13 @@
mod attachment_registry;
mod project_context;
mod tool_registry;
pub use attachment_registry::{
AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
UserAttachment,
};
pub use project_context::ProjectContext;
pub use tool_registry::{
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition,
ToolRegistry, ToolView,
};


@@ -0,0 +1,234 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
use gpui::{AnyView, Render, Task, View, WindowContext};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use std::{
any::TypeId,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
use util::ResultExt as _;
pub struct AttachmentRegistry {
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
}
pub trait AttachmentOutput {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
pub trait LanguageModelAttachment {
type Output: DeserializeOwned + Serialize + 'static;
type View: Render + AttachmentOutput;
fn name(&self) -> Arc<str>;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
/// A collected attachment from running an attachment tool
pub struct UserAttachment {
pub view: AnyView,
name: Arc<str>,
serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
}
#[derive(Serialize, Deserialize)]
pub struct SavedUserAttachment {
name: Arc<str>,
serialized_output: Result<Box<RawValue>, String>,
}
/// Internal representation of an attachment tool, allowing us to treat attachments dynamically
struct RegisteredAttachment {
name: Arc<str>,
enabled: AtomicBool,
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
}
impl AttachmentRegistry {
pub fn new() -> Self {
Self {
registered_attachments: HashMap::default(),
}
}
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
let attachment = Arc::new(attachment);
let call = Box::new({
let attachment = attachment.clone();
move |cx: &mut WindowContext| {
let result = attachment.run(cx);
let attachment = attachment.clone();
cx.spawn(move |mut cx| async move {
let result: Result<A::Output> = result.await;
let serialized_output =
result
.as_ref()
.map_err(ToString::to_string)
.and_then(|output| {
Ok(RawValue::from_string(
serde_json::to_string(output).map_err(|e| e.to_string())?,
)
.unwrap())
});
let view = cx.update(|cx| attachment.view(result, cx))?;
Ok(UserAttachment {
name: attachment.name(),
view: view.into(),
generate_fn: generate::<A>,
serialized_output,
})
})
}
});
let deserialize = Box::new({
let attachment = attachment.clone();
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
let serialized_output = saved_attachment.serialized_output.clone();
let output = match &serialized_output {
Ok(serialized_output) => {
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
}
Err(error) => Err(anyhow!("{error}")),
};
let view = attachment.view(output, cx).into();
Ok(UserAttachment {
name: saved_attachment.name.clone(),
view,
serialized_output,
generate_fn: generate::<A>,
})
}
});
self.registered_attachments.insert(
TypeId::of::<A>(),
RegisteredAttachment {
name: attachment.name(),
call,
deserialize,
enabled: AtomicBool::new(true),
},
);
return;
fn generate<T: LanguageModelAttachment>(
view: AnyView,
project: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
view.downcast::<T::View>()
.unwrap()
.update(cx, |view, cx| T::View::generate(view, project, cx))
}
}
pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
&self,
is_enabled: bool,
) {
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
attachment.enabled.store(is_enabled, SeqCst);
}
}
pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
attachment.enabled.load(SeqCst)
} else {
false
}
}
pub fn call<A: LanguageModelAttachment + 'static>(
&self,
cx: &mut WindowContext,
) -> Task<Result<UserAttachment>> {
let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
return Task::ready(Err(anyhow!("no attachment tool")));
};
(attachment.call)(cx)
}
pub fn call_all_attachment_tools(
self: Arc<Self>,
cx: &mut WindowContext<'_>,
) -> Task<Result<Vec<UserAttachment>>> {
let this = self.clone();
cx.spawn(|mut cx| async move {
let attachment_tasks = cx.update(|cx| {
let mut tasks = Vec::new();
for attachment in this
.registered_attachments
.values()
.filter(|attachment| attachment.enabled.load(SeqCst))
{
tasks.push((attachment.call)(cx))
}
tasks
})?;
let attachments = join_all(attachment_tasks.into_iter()).await;
Ok(attachments
.into_iter()
.filter_map(|attachment| attachment.log_err())
.collect())
})
}
pub fn serialize_user_attachment(
&self,
user_attachment: &UserAttachment,
) -> SavedUserAttachment {
SavedUserAttachment {
name: user_attachment.name.clone(),
serialized_output: user_attachment.serialized_output.clone(),
}
}
pub fn deserialize_user_attachment(
&self,
saved_user_attachment: SavedUserAttachment,
cx: &mut WindowContext,
) -> Result<UserAttachment> {
if let Some(registered_attachment) = self
.registered_attachments
.values()
.find(|attachment| attachment.name == saved_user_attachment.name)
{
(registered_attachment.deserialize)(&saved_user_attachment, cx)
} else {
Err(anyhow!(
"no attachment tool for name {}",
saved_user_attachment.name
))
}
}
}
impl UserAttachment {
pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
let result = (self.generate_fn)(self.view.clone(), output, cx);
if result.is_empty() {
None
} else {
Some(result)
}
}
}


@@ -0,0 +1,296 @@
use anyhow::{anyhow, Result};
use gpui::{AppContext, Model, Task, WeakModel};
use project::{Fs, Project, ProjectPath, Worktree};
use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
use sum_tree::TreeMap;
pub struct ProjectContext {
files: TreeMap<ProjectPath, PathState>,
project: WeakModel<Project>,
fs: Arc<dyn Fs>,
}
#[derive(Debug, Clone)]
enum PathState {
PathOnly,
EntireFile,
Excerpts { ranges: Vec<Range<usize>> },
}
impl ProjectContext {
pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
Self {
files: TreeMap::default(),
fs,
project,
}
}
pub fn add_path(&mut self, project_path: ProjectPath) {
if self.files.get(&project_path).is_none() {
self.files.insert(project_path, PathState::PathOnly);
}
}
pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
let previous_state = self
.files
.get(&project_path)
.unwrap_or(&PathState::PathOnly);
let mut ranges = match previous_state {
PathState::EntireFile => return,
PathState::PathOnly => Vec::new(),
PathState::Excerpts { ranges } => ranges.to_vec(),
};
for new_range in new_ranges {
let ix = ranges.binary_search_by(|probe| {
if probe.end < new_range.start {
Ordering::Less
} else if probe.start > new_range.end {
Ordering::Greater
} else {
Ordering::Equal
}
});
match ix {
Ok(mut ix) => {
let existing = &mut ranges[ix];
existing.start = existing.start.min(new_range.start);
existing.end = existing.end.max(new_range.end);
while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
ranges.remove(ix + 1);
}
while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
ranges.remove(ix - 1);
ix -= 1;
}
}
Err(ix) => {
ranges.insert(ix, new_range.clone());
}
}
}
self.files
.insert(project_path, PathState::Excerpts { ranges });
}
pub fn add_file(&mut self, project_path: ProjectPath) {
self.files.insert(project_path, PathState::EntireFile);
}
pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
let project = self
.project
.upgrade()
.ok_or_else(|| anyhow!("project dropped"));
let files = self.files.clone();
let fs = self.fs.clone();
cx.spawn(|cx| async move {
let project = project?;
let mut result = "project structure:\n".to_string();
let mut last_worktree: Option<Model<Worktree>> = None;
for (project_path, path_state) in files.iter() {
if let Some(worktree) = &last_worktree {
if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
last_worktree = None;
}
}
let worktree;
if let Some(last_worktree) = &last_worktree {
worktree = last_worktree.clone();
} else if let Some(tree) = project.read_with(&cx, |project, cx| {
project.worktree_for_id(project_path.worktree_id, cx)
})? {
worktree = tree;
last_worktree = Some(worktree.clone());
let worktree_name =
worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
writeln!(&mut result, "# {}", worktree_name).unwrap();
} else {
continue;
}
let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
let path = &project_path.path;
writeln!(&mut result, "## {}", path.display()).unwrap();
match path_state {
PathState::PathOnly => {}
PathState::EntireFile => {
let text = fs.load(&worktree_abs_path.join(&path)).await?;
writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
}
PathState::Excerpts { ranges } => {
let text = fs.load(&worktree_abs_path.join(&path)).await?;
writeln!(&mut result, "~~~").unwrap();
// Assumption: ranges are in order, not overlapping
let mut prev_range_end = 0;
for range in ranges {
if range.start > prev_range_end {
writeln!(&mut result, "...").unwrap();
prev_range_end = range.end;
}
let mut start = range.start;
let mut end = range.end.min(text.len());
while !text.is_char_boundary(start) {
start += 1;
}
while !text.is_char_boundary(end) {
end -= 1;
}
result.push_str(&text[start..end]);
if !result.ends_with('\n') {
result.push('\n');
}
}
if prev_range_end < text.len() {
writeln!(&mut result, "...").unwrap();
}
writeln!(&mut result, "~~~").unwrap();
}
}
}
Ok(result)
})
}
}
#[cfg(test)]
mod tests {
use std::path::Path;
use super::*;
use gpui::TestAppContext;
use project::FakeFs;
use serde_json::json;
use settings::SettingsStore;
use unindent::Unindent as _;
#[gpui::test]
async fn test_system_message_generation(cx: &mut TestAppContext) {
init_test(cx);
let file_3_contents = r#"
fn test1() {}
fn test2() {}
fn test3() {}
"#
.unindent();
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/code",
json!({
"root1": {
"lib": {
"file1.rs": "mod example;",
"file2.rs": "",
},
"test": {
"file3.rs": file_3_contents,
}
},
"root2": {
"src": {
"main.rs": ""
}
}
}),
)
.await;
let project = Project::test(
fs.clone(),
["/code/root1".as_ref(), "/code/root2".as_ref()],
cx,
)
.await;
let worktree_ids = project.read_with(cx, |project, cx| {
project
.worktrees()
.map(|worktree| worktree.read(cx).id())
.collect::<Vec<_>>()
});
let mut ax = ProjectContext::new(project.downgrade(), fs);
ax.add_file(ProjectPath {
worktree_id: worktree_ids[0],
path: Path::new("lib/file1.rs").into(),
});
let message = cx
.update(|cx| ax.generate_system_message(cx))
.await
.unwrap();
assert_eq!(
r#"
project structure:
# root1
## lib/file1.rs
~~~
mod example;
~~~
"#
.unindent(),
message
);
ax.add_excerpts(
ProjectPath {
worktree_id: worktree_ids[0],
path: Path::new("test/file3.rs").into(),
},
&[
file_3_contents.find("fn test2").unwrap()
..file_3_contents.find("fn test3").unwrap(),
],
);
let message = cx
.update(|cx| ax.generate_system_message(cx))
.await
.unwrap();
assert_eq!(
r#"
project structure:
# root1
## lib/file1.rs
~~~
mod example;
~~~
## test/file3.rs
~~~
...
fn test2() {}
...
~~~
"#
.unindent(),
message
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
});
}
}


@@ -0,0 +1,526 @@
use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
use repair_json::repair;
use schemars::{schema::RootSchema, schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
mem,
sync::atomic::{AtomicBool, Ordering::SeqCst},
};
use ui::ViewContext;
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
}
#[derive(Default)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
state: ToolFunctionCallState,
}
#[derive(Default)]
enum ToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
KnownTool(Box<dyn InternalToolView>),
ExecutedTool(Box<dyn InternalToolView>),
}
trait InternalToolView {
fn view(&self) -> AnyView;
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
}
#[derive(Default, Serialize, Deserialize)]
pub struct SavedToolFunctionCall {
id: String,
name: String,
arguments: String,
state: SavedToolFunctionCallState,
}
#[derive(Default, Serialize, Deserialize)]
enum SavedToolFunctionCallState {
#[default]
Initializing,
NoSuchTool,
KnownTool,
ExecutedTool(Box<RawValue>),
}
#[derive(Clone, Debug, PartialEq)]
pub struct ToolFunctionDefinition {
pub name: String,
pub description: String,
pub parameters: RootSchema,
}
pub trait LanguageModelTool {
type View: ToolView;
/// Returns the name of the tool.
///
/// This name is exposed to the language model to allow the model to pick
/// which tools to use. As this name is used to identify the tool within a
/// tool registry, it should be unique.
fn name(&self) -> String;
/// Returns the description of the tool.
///
/// This can be used to _prompt_ the model as to what the tool does.
fn description(&self) -> String;
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
fn definition(&self) -> ToolFunctionDefinition {
let root_schema = schema_for!(<Self::View as ToolView>::Input);
ToolFunctionDefinition {
name: self.name(),
description: self.description(),
parameters: root_schema,
}
}
/// A view of the output of running the tool, for displaying to the user.
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
}
pub trait ToolView: Render {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
type SerializedState: DeserializeOwned + Serialize;
fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
fn deserialize(
&mut self,
output: Self::SerializedState,
cx: &mut ViewContext<Self>,
) -> Result<()>;
}
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
definition: ToolFunctionDefinition,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
registered_tools: HashMap::new(),
}
}
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
tool.enabled.store(is_enabled, SeqCst);
return;
}
}
}
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
for tool in self.registered_tools.values() {
if tool.type_id == TypeId::of::<T>() {
return tool.enabled.load(SeqCst);
}
}
false
}
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
self.registered_tools
.values()
.filter(|tool| tool.enabled.load(SeqCst))
.map(|tool| tool.definition.clone())
.collect()
}
pub fn update_tool_call(
&self,
call: &mut ToolFunctionCall,
name: Option<&str>,
arguments: Option<&str>,
cx: &mut WindowContext,
) {
if let Some(name) = name {
call.name.push_str(name);
}
if let Some(arguments) = arguments {
if call.arguments.is_empty() {
if let Some(tool) = self.registered_tools.get(&call.name) {
let view = (tool.build_view)(cx);
call.state = ToolFunctionCallState::KnownTool(view);
} else {
call.state = ToolFunctionCallState::NoSuchTool;
}
}
call.arguments.push_str(arguments);
if let ToolFunctionCallState::KnownTool(view) = &call.state {
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
view.try_set_input(&repaired_arguments, cx)
}
}
}
}
pub fn execute_tool_call(
&self,
tool_call: &mut ToolFunctionCall,
cx: &mut WindowContext,
) -> Option<Task<Result<()>>> {
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
let task = view.execute(cx);
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
Some(task)
} else {
None
}
}
pub fn render_tool_call(
&self,
tool_call: &ToolFunctionCall,
_cx: &mut WindowContext,
) -> Option<AnyElement> {
match &tool_call.state {
ToolFunctionCallState::NoSuchTool => {
Some(ui::Label::new("No such tool").into_any_element())
}
ToolFunctionCallState::Initializing => None,
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
Some(view.view().into_any_element())
}
}
}
pub fn content_for_tool_call(
&self,
tool_call: &ToolFunctionCall,
project_context: &mut ProjectContext,
cx: &mut WindowContext,
) -> String {
match &tool_call.state {
ToolFunctionCallState::Initializing => String::new(),
ToolFunctionCallState::NoSuchTool => {
format!("No such tool: {}", tool_call.name)
}
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
view.generate(project_context, cx)
}
}
}
pub fn serialize_tool_call(
&self,
call: &ToolFunctionCall,
cx: &mut WindowContext,
) -> Result<SavedToolFunctionCall> {
Ok(SavedToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
state: match &call.state {
ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
ToolFunctionCallState::ExecutedTool(view) => {
SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
}
},
})
}
pub fn deserialize_tool_call(
&self,
call: &SavedToolFunctionCall,
cx: &mut WindowContext,
) -> Result<ToolFunctionCall> {
let Some(tool) = self.registered_tools.get(&call.name) else {
return Err(anyhow!("no such tool {}", call.name));
};
Ok(ToolFunctionCall {
id: call.id.clone(),
name: call.name.clone(),
arguments: call.arguments.clone(),
state: match &call.state {
SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
SavedToolFunctionCallState::KnownTool => {
log::error!("Deserialized tool that had not executed");
let view = (tool.build_view)(cx);
view.try_set_input(&call.arguments, cx);
ToolFunctionCallState::KnownTool(view)
}
SavedToolFunctionCallState::ExecutedTool(output) => {
let view = (tool.build_view)(cx);
view.try_set_input(&call.arguments, cx);
view.deserialize_output(output, cx)?;
ToolFunctionCallState::ExecutedTool(view)
}
},
})
}
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
let name = tool.name();
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
};
let previous = self.registered_tools.insert(name.clone(), registered_tool);
if previous.is_some() {
return Err(anyhow!("already registered a tool with name {}", name));
}
return Ok(());
}
}
impl<T: ToolView> InternalToolView for View<T> {
fn view(&self) -> AnyView {
self.clone().into()
}
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
self.update(cx, |view, cx| view.generate(project, cx))
}
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
self.update(cx, |view, cx| {
view.set_input(input, cx);
cx.notify();
});
}
}
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
self.update(cx, |view, cx| view.execute(cx))
}
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
let output = self.update(cx, |view, cx| view.serialize(cx));
Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
}
fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
let state = serde_json::from_str::<T::SerializedState>(output.get())?;
self.update(cx, |view, cx| view.deserialize(state, cx))?;
Ok(())
}
}
impl Display for ToolFunctionDefinition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let schema = serde_json::to_string(&self.parameters).ok();
let schema = schema.unwrap_or("None".to_string());
write!(f, "Name: {}:\n", self.name)?;
write!(f, "Description: {}\n", self.description)?;
write!(f, "Parameters: {}", schema)
}
}
#[cfg(test)]
mod test {
use super::*;
use gpui::{div, prelude::*, Render, TestAppContext};
use gpui::{EmptyView, View};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize, Serialize, JsonSchema)]
struct WeatherQuery {
location: String,
unit: String,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct WeatherResult {
location: String,
temperature: f64,
unit: String,
}
struct WeatherView {
input: Option<WeatherQuery>,
result: Option<WeatherResult>,
// Fake API call
current_weather: WeatherResult,
}
#[derive(Clone, Serialize)]
struct WeatherTool {
current_weather: WeatherResult,
}
impl WeatherView {
fn new(current_weather: WeatherResult) -> Self {
Self {
input: None,
result: None,
current_weather,
}
}
}
impl Render for WeatherView {
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
match self.result {
Some(ref result) => div()
.child(format!("temperature: {}", result.temperature))
.into_any_element(),
None => div().child("Calculating weather...").into_any_element(),
}
}
}
impl ToolView for WeatherView {
type Input = WeatherQuery;
type SerializedState = WeatherResult;
fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
serde_json::to_string(&self.result).unwrap()
}
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
self.input = Some(input);
cx.notify();
}
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
let input = self.input.as_ref().unwrap();
let _location = input.location.clone();
let _unit = input.unit.clone();
let weather = self.current_weather.clone();
self.result = Some(weather);
Task::ready(Ok(()))
}
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
self.current_weather.clone()
}
fn deserialize(
&mut self,
output: Self::SerializedState,
_cx: &mut ViewContext<Self>,
) -> Result<()> {
self.current_weather = output;
Ok(())
}
}
impl LanguageModelTool for WeatherTool {
type View = WeatherView;
fn name(&self) -> String {
"get_current_weather".to_string()
}
fn description(&self) -> String {
"Fetches the current weather for a given location.".to_string()
}
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
}
}
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
let mut registry = ToolRegistry::new();
registry
.register(WeatherTool {
current_weather: WeatherResult {
location: "San Francisco".to_string(),
temperature: 21.0,
unit: "Celsius".to_string(),
},
})
.unwrap();
let definitions = registry.definitions();
assert_eq!(
definitions,
[ToolFunctionDefinition {
name: "get_current_weather".to_string(),
description: "Fetches the current weather for a given location.".to_string(),
parameters: serde_json::from_value(json!({
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "WeatherQuery",
"type": "object",
"properties": {
"location": {
"type": "string"
},
"unit": {
"type": "string"
}
},
"required": ["location", "unit"]
}))
.unwrap(),
}]
);
let mut call = ToolFunctionCall {
id: "the-id".to_string(),
name: "get_cur".to_string(),
..Default::default()
};
let task = cx.update(|cx| {
registry.update_tool_call(
&mut call,
Some("rent_weather"),
Some(r#"{"location": "San Francisco","#),
cx,
);
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
registry.execute_tool_call(&mut call, cx).unwrap()
});
task.await.unwrap();
match &call.state {
ToolFunctionCallState::ExecutedTool(_view) => {}
_ => panic!(),
}
}
}


@@ -18,12 +18,11 @@ client.workspace = true
db.workspace = true
editor.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
isahc.workspace = true
log.workspace = true
markdown_preview.workspace = true
menu.workspace = true
paths.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true


@@ -20,7 +20,7 @@ use smol::{fs, io::AsyncReadExt};
use settings::{Settings, SettingsSources, SettingsStore};
use smol::{fs::File, process::Command};
use http_client::{HttpClient, HttpClientWithUrl};
use http::{HttpClient, HttpClientWithUrl};
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
use std::{
env::{
@@ -28,7 +28,7 @@ use std::{
consts::{ARCH, OS},
},
ffi::OsString,
path::{Path, PathBuf},
path::PathBuf,
sync::Arc,
time::Duration,
};
@@ -359,6 +359,7 @@ impl AutoUpdater {
return;
}
self.status = AutoUpdateStatus::Checking;
cx.notify();
self.pending_poll = Some(cx.spawn(|this, mut cx| async move {
@@ -384,65 +385,29 @@ impl AutoUpdater {
cx.notify();
}
pub async fn get_latest_remote_server_release(
os: &str,
arch: &str,
mut release_channel: ReleaseChannel,
cx: &mut AsyncAppContext,
) -> Result<PathBuf> {
let this = cx.update(|cx| {
cx.default_global::<GlobalAutoUpdate>()
.0
.clone()
.ok_or_else(|| anyhow!("auto-update not initialized"))
})??;
async fn update(this: Model<Self>, mut cx: AsyncAppContext) -> Result<()> {
let (client, current_version) = this.read_with(&cx, |this, _| {
(this.http_client.clone(), this.current_version)
})?;
if release_channel == ReleaseChannel::Dev {
release_channel = ReleaseChannel::Nightly;
}
let asset = match OS {
"linux" => format!("zed-linux-{}.tar.gz", ARCH),
"macos" => "Zed.dmg".into(),
_ => return Err(anyhow!("auto-update not supported for OS {:?}", OS)),
};
let release = Self::get_latest_release(
&this,
"zed-remote-server",
os,
arch,
Some(release_channel),
cx,
)
.await?;
let servers_dir = paths::remote_servers_dir();
let channel_dir = servers_dir.join(release_channel.dev_name());
let platform_dir = channel_dir.join(format!("{}-{}", os, arch));
let version_path = platform_dir.join(format!("{}.gz", release.version));
smol::fs::create_dir_all(&platform_dir).await.ok();
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
if smol::fs::metadata(&version_path).await.is_err() {
log::info!("downloading zed-remote-server {os} {arch}");
download_remote_server_binary(&version_path, release, client, cx).await?;
}
Ok(version_path)
}
async fn get_latest_release(
this: &Model<Self>,
asset: &str,
os: &str,
arch: &str,
release_channel: Option<ReleaseChannel>,
cx: &mut AsyncAppContext,
) -> Result<JsonRelease> {
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
let mut url_string = client.build_url(&format!(
"/api/releases/latest?asset={}&os={}&arch={}",
asset, os, arch
asset, OS, ARCH
));
if let Some(param) = release_channel.and_then(|c| c.release_query_param()) {
url_string += "&";
url_string += param;
}
cx.update(|cx| {
if let Some(param) = ReleaseChannel::try_global(cx)
.and_then(|release_channel| release_channel.release_query_param())
{
url_string += "&";
url_string += param;
}
})?;
let mut response = client.get(&url_string, Default::default(), true).await?;
@@ -453,34 +418,8 @@ impl AutoUpdater {
.await
.context("error reading release")?;
if !response.status().is_success() {
Err(anyhow!(
"failed to fetch release: {:?}",
String::from_utf8_lossy(&body),
))?;
}
serde_json::from_slice(body.as_slice()).with_context(|| {
format!(
"error deserializing release {:?}",
String::from_utf8_lossy(&body),
)
})
}
async fn update(this: Model<Self>, mut cx: AsyncAppContext) -> Result<()> {
let (client, current_version, release_channel) = this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Checking;
cx.notify();
(
this.http_client.clone(),
this.current_version,
ReleaseChannel::try_global(cx),
)
})?;
let release =
Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
let release: JsonRelease =
serde_json::from_slice(body.as_slice()).context("error deserializing release")?;
let should_download = match *RELEASE_CHANNEL {
ReleaseChannel::Nightly => cx
@@ -507,14 +446,7 @@ impl AutoUpdater {
let temp_dir = tempfile::Builder::new()
.prefix("zed-auto-update")
.tempdir()?;
let filename = match OS {
"macos" => Ok("Zed.dmg"),
"linux" => Ok("zed.tar.gz"),
_ => Err(anyhow!("not supported: {:?}", OS)),
}?;
let downloaded_asset = temp_dir.path().join(filename);
download_release(&downloaded_asset, release, client, &cx).await?;
let downloaded_asset = download_release(&temp_dir, release, &asset, client, &cx).await?;
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Installing;
@@ -568,38 +500,14 @@ impl AutoUpdater {
}
}
async fn download_remote_server_binary(
target_path: &PathBuf,
release: JsonRelease,
client: Arc<HttpClientWithUrl>,
cx: &AsyncAppContext,
) -> Result<()> {
let mut target_file = File::create(&target_path).await?;
let (installation_id, release_channel, telemetry) = cx.update(|cx| {
let installation_id = Client::global(cx).telemetry().installation_id();
let release_channel =
ReleaseChannel::try_global(cx).map(|release_channel| release_channel.display_name());
let telemetry = TelemetrySettings::get_global(cx).metrics;
(installation_id, release_channel, telemetry)
})?;
let request_body = AsyncBody::from(serde_json::to_string(&UpdateRequestBody {
installation_id,
release_channel,
telemetry,
})?);
let mut response = client.get(&release.url, request_body, true).await?;
smol::io::copy(response.body_mut(), &mut target_file).await?;
Ok(())
}
async fn download_release(
target_path: &Path,
temp_dir: &tempfile::TempDir,
release: JsonRelease,
target_filename: &str,
client: Arc<HttpClientWithUrl>,
cx: &AsyncAppContext,
) -> Result<()> {
) -> Result<PathBuf> {
let target_path = temp_dir.path().join(target_filename);
let mut target_file = File::create(&target_path).await?;
let (installation_id, release_channel, telemetry) = cx.update(|cx| {
@@ -621,7 +529,7 @@ async fn download_release(
smol::io::copy(response.body_mut(), &mut target_file).await?;
log::info!("downloaded update. path:{:?}", target_path);
Ok(())
Ok(target_path)
}
async fn install_release_linux(
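
The URL assembled piecemeal in get_latest_release above reduces to a small pure function. A sketch, assuming channel_param carries whatever release_query_param() yields (e.g. preview=1):

fn build_release_url(
    base: &str,
    asset: &str,
    os: &str,
    arch: &str,
    channel_param: Option<&str>,
) -> String {
    // e.g. {base}/api/releases/latest?asset=zed&os=macos&arch=aarch64&preview=1
    let mut url = format!("{base}/api/releases/latest?asset={asset}&os={os}&arch={arch}");
    if let Some(param) = channel_param {
        url.push('&');
        url.push_str(param);
    }
    url
}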

View File

@@ -51,4 +51,4 @@ language = { workspace = true, features = ["test-support"] }
live_kit_client = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

View File

@@ -493,7 +493,7 @@ impl Room {
// we leave the room and return an error.
if let Some(this) = this.upgrade() {
log::info!("reconnection failed, leaving room");
let _ = this.update(&mut cx, |this, cx| this.leave(cx))?.await?;
let _ = this.update(&mut cx, |this, cx| this.leave(cx))?;
}
Err(anyhow!(
"can't reconnect to room: client failed to re-establish connection"
@@ -526,7 +526,7 @@ impl Room {
rejoined_projects.push(proto::RejoinProject {
id: project_id,
worktrees: project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
proto::RejoinWorktree {
@@ -942,7 +942,7 @@ impl Room {
this.pending_room_update.take();
if this.should_leave() {
log::info!("room is empty, leaving");
let _ = this.leave(cx).detach();
let _ = this.leave(cx);
}
this.user_store.update(cx, |user_store, cx| {

View File

@@ -40,4 +40,4 @@ rpc = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

View File

@@ -4,7 +4,7 @@ use super::*;
use client::{test::FakeServer, Client, UserStore};
use clock::FakeSystemClock;
use gpui::{AppContext, Context, Model, SemanticVersion, TestAppContext};
use http_client::FakeHttpClient;
use http::FakeHttpClient;
use rpc::proto::{self};
use settings::SettingsStore;

View File

@@ -129,7 +129,6 @@ fn main() -> Result<()> {
|| path.starts_with("http://")
|| path.starts_with("https://")
|| path.starts_with("file://")
|| path.starts_with("ssh://")
{
urls.push(path.to_string());
} else {
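
The dispatch above is a simple scheme-prefix check: arguments with a known URL scheme are forwarded as URLs, everything else is treated as a path. In isolation (the hunk's first alternative is cut off above, so this list is illustrative rather than exhaustive):

fn looks_like_url(arg: &str) -> bool {
    // `ssh://` is the prefix this diff removes from the accepted set.
    ["http://", "https://", "file://", "ssh://"]
        .iter()
        .any(|scheme| arg.starts_with(scheme))
}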

View File

@@ -18,7 +18,7 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
[dependencies]
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { workspace = true, features = ["async-std", "async-native-tls"] }
async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] }
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
collections.workspace = true
@@ -26,7 +26,7 @@ feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
lazy_static.workspace = true
log.workspace = true
once_cell.workspace = true
@@ -60,11 +60,12 @@ gpui = { workspace = true, features = ["test-support"] }
rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }
[target.'cfg(target_os = "windows")'.dependencies]
windows.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
cocoa.workspace = true
isahc = { workspace = true, features = ["static-curl"] }
async-native-tls = { version = "0.5.0", features = ["vendored"] }

View File

@@ -7,9 +7,8 @@ pub mod user;
use anyhow::{anyhow, Context as _, Result};
use async_recursion::async_recursion;
use async_tungstenite::tungstenite::{
client::IntoClientRequest,
error::Error as WebsocketError,
http::{HeaderValue, Request, StatusCode},
http::{Request, StatusCode},
};
use clock::SystemClock;
use collections::HashMap;
@@ -21,7 +20,7 @@ use futures::{
use gpui::{
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use http::{HttpClient, HttpClientWithUrl};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use postage::watch;
@@ -234,9 +233,7 @@ pub enum EstablishConnectionError {
#[error("{0}")]
Other(#[from] anyhow::Error),
#[error("{0}")]
Http(#[from] http_client::Error),
#[error("{0}")]
InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
Http(#[from] http::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
@@ -1162,24 +1159,19 @@ impl Client {
.ok()
.unwrap_or_default();
let request = Request::builder()
.header("Authorization", credentials.authorization_header())
.header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
.header("x-zed-app-version", app_version)
.header(
"x-zed-release-channel",
release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
);
let http = self.http.clone();
let credentials = credentials.clone();
let rpc_url = self.rpc_url(http, release_channel);
cx.background_executor().spawn(async move {
use HttpOrHttps::*;
#[derive(Debug)]
enum HttpOrHttps {
Http,
Https,
}
let mut rpc_url = rpc_url.await?;
let url_scheme = match rpc_url.scheme() {
"https" => Https,
"http" => Http,
_ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
};
let rpc_host = rpc_url
.host_str()
.zip(rpc_url.port_or_known_default())
@@ -1188,37 +1180,10 @@ impl Client {
log::info!("connected to rpc endpoint {}", rpc_url);
rpc_url
.set_scheme(match url_scheme {
Https => "wss",
Http => "ws",
})
.unwrap();
// We call `into_client_request` to let `tungstenite` construct the WebSocket request
// for us from the RPC URL.
//
// Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
let mut request = rpc_url.into_client_request()?;
// We then modify the request to add our desired headers.
let request_headers = request.headers_mut();
request_headers.insert(
"Authorization",
HeaderValue::from_str(&credentials.authorization_header())?,
);
request_headers.insert(
"x-zed-protocol-version",
HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
);
request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
request_headers.insert(
"x-zed-release-channel",
HeaderValue::from_str(&release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
);
match url_scheme {
Https => {
match rpc_url.scheme() {
"https" => {
rpc_url.set_scheme("wss").unwrap();
let request = request.uri(rpc_url.as_str()).body(())?;
let (stream, _) =
async_tungstenite::async_std::client_async_tls(request, stream).await?;
Ok(Connection::new(
@@ -1227,7 +1192,9 @@ impl Client {
.sink_map_err(|error| anyhow!(error)),
))
}
Http => {
"http" => {
rpc_url.set_scheme("ws").unwrap();
let request = request.uri(rpc_url.as_str()).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream).await?;
Ok(Connection::new(
stream
@@ -1235,6 +1202,7 @@ impl Client {
.sink_map_err(|error| anyhow!(error)),
))
}
_ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
}
})
}
@@ -1383,7 +1351,7 @@ impl Client {
let mut url = self.rpc_url(http.clone(), None).await?;
url.set_path("/user");
url.set_query(Some(&format!("github_login={login}")));
let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
let request = Request::get(url.as_str())
.header("Authorization", format!("token {api_token}"))
.body("".into())?;
@@ -1442,7 +1410,7 @@ impl Client {
self.peer.send(self.connection_id()?, message)
}
pub fn send_dynamic(&self, envelope: proto::Envelope) -> Result<()> {
fn send_dynamic(&self, envelope: proto::Envelope) -> Result<()> {
let connection_id = self.connection_id()?;
self.peer.send_dynamic(connection_id, envelope)
}
@@ -1815,7 +1783,7 @@ mod tests {
use clock::FakeSystemClock;
use gpui::{BackgroundExecutor, Context, TestAppContext};
use http_client::FakeHttpClient;
use http::FakeHttpClient;
use parking_lot::Mutex;
use proto::TypedEnvelope;
use settings::SettingsStore;
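
Earlier in this file, the reverted side drops the `into_client_request` flow that lets tungstenite assemble the WebSocket upgrade request (including the generated Sec-WebSocket-Key header) before the app-specific headers are appended. That pattern in isolation, with placeholder URL and header values:

use async_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue};

type WsRequest = async_tungstenite::tungstenite::handshake::client::Request;

fn build_rpc_request() -> anyhow::Result<WsRequest> {
    // `into_client_request` fills in Host, Connection, Upgrade and a random
    // Sec-WebSocket-Key for us; we only append the app-specific headers.
    let mut request = "wss://example.invalid/rpc".into_client_request()?;
    let headers = request.headers_mut();
    headers.insert("x-zed-protocol-version", HeaderValue::from_str("1")?);
    headers.insert("x-zed-app-version", HeaderValue::from_str("0.0.0")?);
    Ok(request)
}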

View File

@@ -6,7 +6,7 @@ use clock::SystemClock;
use collections::{HashMap, HashSet};
use futures::Future;
use gpui::{AppContext, BackgroundExecutor, Task};
use http_client::{self, HttpClient, HttpClientWithUrl, Method};
use http::{self, HttpClient, HttpClientWithUrl, Method};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use release_channel::ReleaseChannel;
@@ -18,7 +18,7 @@ use sysinfo::{CpuRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CpuEvent, EditEvent,
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent,
MemoryEvent, ReplEvent, SettingEvent,
MemoryEvent, SettingEvent,
};
use tempfile::NamedTempFile;
#[cfg(not(debug_assertions))]
@@ -531,21 +531,6 @@ impl Telemetry {
}
}
pub fn report_repl_event(
self: &Arc<Self>,
kernel_language: String,
kernel_status: String,
repl_session_id: String,
) {
let event = Event::Repl(ReplEvent {
kernel_language,
kernel_status,
repl_session_id,
});
self.report_event(event)
}
fn report_event(self: &Arc<Self>, event: Event) {
let mut state = self.state.lock();
@@ -647,7 +632,7 @@ impl Telemetry {
let checksum = calculate_json_checksum(&json_bytes).unwrap_or("".to_string());
let request = http_client::Request::builder()
let request = http::Request::builder()
.method(Method::POST)
.uri(
this.http_client
@@ -676,7 +661,7 @@ mod tests {
use chrono::TimeZone;
use clock::FakeSystemClock;
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use http::FakeHttpClient;
#[gpui::test]
fn test_telemetry_flush_on_max_queue_size(cx: &mut TestAppContext) {

View File

@@ -20,7 +20,7 @@ test-support = ["sqlite"]
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
async-tungstenite.workspace = true
async-tungstenite = "0.16"
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }
axum = { version = "0.6", features = ["json", "headers", "ws"] }
@@ -30,12 +30,12 @@ chrono.workspace = true
clock.workspace = true
clickhouse.workspace = true
collections.workspace = true
dashmap.workspace = true
dashmap = "5.4"
envy = "0.4.2"
futures.workspace = true
google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
http.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid.workspace = true
@@ -47,7 +47,7 @@ prost.workspace = true
rand.workspace = true
reqwest = { version = "0.11", features = ["json"] }
rpc.workspace = true
scrypt = "0.11"
scrypt = "0.7"
sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
semantic_version.workspace = true
semver.workspace = true
@@ -79,7 +79,6 @@ channel.workspace = true
client = { workspace = true, features = ["test-support"] }
collab_ui = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
@@ -90,7 +89,6 @@ git_hosting_providers.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
live_kit_client = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
menu.workspace = true
@@ -101,13 +99,10 @@ pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
recent_projects = { workspace = true }
release_channel.workspace = true
remote = { workspace = true, features = ["test-support"] }
remote_server.workspace = true
dev_server_projects.workspace = true
rpc = { workspace = true, features = ["test-support"] }
sea-orm = { version = "0.12.x", features = ["sqlx-sqlite"] }
serde_json.workspace = true
session = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
sqlx = { version = "0.7", features = ["sqlite"] }
theme.workspace = true

138 crates/collab/src/ai.rs Normal file
View File

@@ -0,0 +1,138 @@
use anyhow::{anyhow, Context as _, Result};
use rpc::proto;
use util::ResultExt as _;
pub fn language_model_request_to_open_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<open_ai::Request> {
Ok(open_ai::Request {
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
messages: request
.messages
.into_iter()
.map(|message: proto::LanguageModelRequestMessage| {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
let openai_message = match role {
proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
content: message.content,
},
proto::LanguageModelRole::LanguageModelAssistant => {
open_ai::RequestMessage::Assistant {
content: Some(message.content),
tool_calls: message
.tool_calls
.into_iter()
.filter_map(|call| {
Some(open_ai::ToolCall {
id: call.id,
content: match call.variant? {
proto::tool_call::Variant::Function(f) => {
open_ai::ToolCallContent::Function {
function: open_ai::FunctionContent {
name: f.name,
arguments: f.arguments,
},
}
}
},
})
})
.collect(),
}
}
proto::LanguageModelRole::LanguageModelSystem => {
open_ai::RequestMessage::System {
content: message.content,
}
}
proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
tool_call_id: message
.tool_call_id
.ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
content: message.content,
},
};
Ok(openai_message)
})
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
stream: true,
stop: request.stop,
temperature: request.temperature,
tools: request
.tools
.into_iter()
.filter_map(|tool| {
Some(match tool.variant? {
proto::chat_completion_tool::Variant::Function(f) => {
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: f.name,
description: f.description,
parameters: if let Some(params) = &f.parameters {
Some(
serde_json::from_str(params)
.context("failed to deserialize tool parameters")
.log_err()?,
)
} else {
None
},
},
}
}
})
})
.collect(),
tool_choice: request.tool_choice,
})
}
pub fn language_model_request_to_google_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<google_ai::GenerateContentRequest> {
Ok(google_ai::GenerateContentRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
generation_config: None,
safety_settings: None,
})
}
pub fn language_model_request_message_to_google_ai(
message: proto::LanguageModelRequestMessage,
) -> Result<google_ai::Content> {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: message.content,
})],
role: match role {
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelTool => {
Err(anyhow!("we don't handle tool calls with google ai yet"))?
}
},
})
}
pub fn count_tokens_request_to_google_ai(
request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {
Ok(google_ai::CountTokensRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
})
}
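
One subtlety in this new module: the Google API has no system role, so language_model_request_message_to_google_ai maps system messages onto the user role. A toy version of that mapping (the enums are stand-ins, not the real proto/google_ai types):

#[derive(Clone, Copy)]
enum ProtoRole { User, Assistant, System }

#[derive(Debug, PartialEq)]
enum GoogleRole { User, Model }

fn to_google_role(role: ProtoRole) -> GoogleRole {
    match role {
        // No system role exists on the Google side, so system text rides along as user content.
        ProtoRole::User | ProtoRole::System => GoogleRole::User,
        ProtoRole::Assistant => GoogleRole::Model,
    }
}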

View File

@@ -1,4 +1,3 @@
pub mod contributors;
pub mod events;
pub mod extensions;
pub mod ips_file;
@@ -6,13 +5,13 @@ pub mod slack;
use crate::{
auth,
db::{User, UserId},
db::{ContributorSelector, User, UserId},
rpc, AppState, Error, Result,
};
use anyhow::anyhow;
use axum::{
body::Body,
extract::{Path, Query},
extract::{self, Path, Query},
http::{self, Request, StatusCode},
middleware::{self, Next},
response::IntoResponse,
@@ -20,6 +19,7 @@ use axum::{
Extension, Json, Router,
};
use axum_extra::response::ErasedJson;
use chrono::SecondsFormat;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower::ServiceBuilder;
@@ -31,7 +31,8 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
.route("/user", get(get_authenticated_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(contributors::router())
.route("/contributors", get(get_contributors).post(add_contributor))
.route("/contributor", get(check_is_contributor))
.layer(
ServiceBuilder::new()
.layer(Extension(state))
@@ -125,6 +126,66 @@ async fn get_rpc_server_snapshot(
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
}
async fn get_contributors(Extension(app): Extension<Arc<AppState>>) -> Result<Json<Vec<String>>> {
Ok(Json(app.db.get_contributors().await?))
}
#[derive(Debug, Deserialize)]
struct CheckIsContributorParams {
github_user_id: Option<i32>,
github_login: Option<String>,
}
impl CheckIsContributorParams {
fn as_contributor_selector(self) -> Result<ContributorSelector> {
if let Some(github_user_id) = self.github_user_id {
return Ok(ContributorSelector::GitHubUserId { github_user_id });
}
if let Some(github_login) = self.github_login {
return Ok(ContributorSelector::GitHubLogin { github_login });
}
Err(anyhow!(
"must be one of `github_user_id` or `github_login`."
))?
}
}
#[derive(Debug, Serialize)]
struct CheckIsContributorResponse {
signed_at: Option<String>,
}
async fn check_is_contributor(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<CheckIsContributorParams>,
) -> Result<Json<CheckIsContributorResponse>> {
let params = params.as_contributor_selector()?;
Ok(Json(CheckIsContributorResponse {
signed_at: app
.db
.get_contributor_sign_timestamp(&params)
.await?
.map(|ts| ts.and_utc().to_rfc3339_opts(SecondsFormat::Millis, true)),
}))
}
async fn add_contributor(
Extension(app): Extension<Arc<AppState>>,
extract::Json(params): extract::Json<AuthenticatedUserParams>,
) -> Result<()> {
let initial_channel_id = app.config.auto_join_channel_id;
app.db
.add_contributor(
&params.github_login,
params.github_user_id,
params.github_email.as_deref(),
initial_channel_id,
)
.await
}
#[derive(Deserialize)]
struct CreateAccessTokenQueryParams {
public_key: String,
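
The signed_at value returned by /contributor is chrono's RFC 3339 rendering with millisecond precision and a Z suffix. A quick sketch of what that produces (the sample timestamp is arbitrary):

use chrono::{NaiveDate, SecondsFormat};

fn main() {
    let ts = NaiveDate::from_ymd_opt(2024, 7, 20)
        .unwrap()
        .and_hms_opt(12, 45, 18)
        .unwrap();
    assert_eq!(
        ts.and_utc().to_rfc3339_opts(SecondsFormat::Millis, true),
        "2024-07-20T12:45:18.000Z"
    );
}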

View File

@@ -1,121 +0,0 @@
use std::sync::{Arc, OnceLock};
use anyhow::anyhow;
use axum::{
extract::{self, Query},
routing::get,
Extension, Json, Router,
};
use chrono::{NaiveDateTime, SecondsFormat};
use serde::{Deserialize, Serialize};
use crate::api::AuthenticatedUserParams;
use crate::db::ContributorSelector;
use crate::{AppState, Result};
pub fn router() -> Router {
Router::new()
.route("/contributors", get(get_contributors).post(add_contributor))
.route("/contributor", get(check_is_contributor))
}
async fn get_contributors(Extension(app): Extension<Arc<AppState>>) -> Result<Json<Vec<String>>> {
Ok(Json(app.db.get_contributors().await?))
}
#[derive(Debug, Deserialize)]
struct CheckIsContributorParams {
github_user_id: Option<i32>,
github_login: Option<String>,
}
impl CheckIsContributorParams {
fn as_contributor_selector(self) -> Result<ContributorSelector> {
if let Some(github_user_id) = self.github_user_id {
return Ok(ContributorSelector::GitHubUserId { github_user_id });
}
if let Some(github_login) = self.github_login {
return Ok(ContributorSelector::GitHubLogin { github_login });
}
Err(anyhow!(
"must be one of `github_user_id` or `github_login`."
))?
}
}
#[derive(Debug, Serialize)]
struct CheckIsContributorResponse {
signed_at: Option<String>,
}
async fn check_is_contributor(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<CheckIsContributorParams>,
) -> Result<Json<CheckIsContributorResponse>> {
let params = params.as_contributor_selector()?;
if RenovateBot::is_renovate_bot(&params) {
return Ok(Json(CheckIsContributorResponse {
signed_at: Some(
RenovateBot::created_at()
.and_utc()
.to_rfc3339_opts(SecondsFormat::Millis, true),
),
}));
}
Ok(Json(CheckIsContributorResponse {
signed_at: app
.db
.get_contributor_sign_timestamp(&params)
.await?
.map(|ts| ts.and_utc().to_rfc3339_opts(SecondsFormat::Millis, true)),
}))
}
/// The Renovate bot GitHub user (`renovate[bot]`).
///
/// https://api.github.com/users/renovate[bot]
struct RenovateBot;
impl RenovateBot {
const LOGIN: &'static str = "renovate[bot]";
const USER_ID: i32 = 29139614;
/// Returns the `created_at` timestamp for the Renovate bot user.
fn created_at() -> &'static NaiveDateTime {
static CREATED_AT: OnceLock<NaiveDateTime> = OnceLock::new();
CREATED_AT.get_or_init(|| {
chrono::DateTime::parse_from_rfc3339("2017-06-02T07:04:12Z")
.expect("failed to parse 'created_at' for 'renovate[bot]'")
.naive_utc()
})
}
/// Returns whether the given contributor selector corresponds to the Renovate bot user.
fn is_renovate_bot(contributor: &ContributorSelector) -> bool {
match contributor {
ContributorSelector::GitHubLogin { github_login } => github_login == Self::LOGIN,
ContributorSelector::GitHubUserId { github_user_id } => {
github_user_id == &Self::USER_ID
}
}
}
}
async fn add_contributor(
Extension(app): Extension<Arc<AppState>>,
extract::Json(params): extract::Json<AuthenticatedUserParams>,
) -> Result<()> {
let initial_channel_id = app.config.auto_join_channel_id;
app.db
.add_contributor(
&params.github_login,
params.github_user_id,
params.github_email.as_deref(),
initial_channel_id,
)
.await
}

View File

@@ -16,7 +16,7 @@ use sha2::{Digest, Sha256};
use std::sync::{Arc, OnceLock};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, CallEvent, CpuEvent, EditEvent, EditorEvent, Event,
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent, ReplEvent,
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent,
SettingEvent,
};
use uuid::Uuid;
@@ -518,13 +518,6 @@ pub async fn post_events(
checksum_matched,
))
}
Event::Repl(event) => to_upload.repl_events.push(ReplEventRow::from_event(
event.clone(),
&wrapper,
&request_body,
first_event_at,
checksum_matched,
)),
}
}
@@ -549,7 +542,6 @@ struct ToUpload {
extension_events: Vec<ExtensionEventRow>,
edit_events: Vec<EditEventRow>,
action_events: Vec<ActionEventRow>,
repl_events: Vec<ReplEventRow>,
}
impl ToUpload {
@@ -625,11 +617,6 @@ impl ToUpload {
.await
.with_context(|| format!("failed to upload to table '{ACTION_EVENTS_TABLE}'"))?;
const REPL_EVENTS_TABLE: &str = "repl_events";
Self::upload_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{REPL_EVENTS_TABLE}'"))?;
Ok(())
}
@@ -638,24 +625,22 @@ impl ToUpload {
rows: &[T],
clickhouse_client: &clickhouse::Client,
) -> anyhow::Result<()> {
if rows.is_empty() {
return Ok(());
if !rows.is_empty() {
let mut insert = clickhouse_client.insert(table)?;
for event in rows {
insert.write(event).await?;
}
insert.end().await?;
let event_count = rows.len();
log::info!(
"wrote {event_count} {event_specifier} to '{table}'",
event_specifier = if event_count == 1 { "event" } else { "events" }
);
}
let mut insert = clickhouse_client.insert(table)?;
for event in rows {
insert.write(event).await?;
}
insert.end().await?;
let event_count = rows.len();
log::info!(
"wrote {event_count} {event_specifier} to '{table}'",
event_specifier = if event_count == 1 { "event" } else { "events" }
);
Ok(())
}
}
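
Both shapes of upload_to_table in this hunk do the same batched insert; they differ only in whether the empty case is an early return or a wrapping if. The guard-clause form standalone (a sketch; the clickhouse calls mirror the diff, and T must be a clickhouse row type):

async fn upload_to_table<T>(
    table: &str,
    rows: &[T],
    client: &clickhouse::Client,
) -> anyhow::Result<()>
where
    T: clickhouse::Row + serde::Serialize,
{
    if rows.is_empty() {
        return Ok(()); // nothing to write; don't open an insert at all
    }
    let mut insert = client.insert(table)?;
    for row in rows {
        insert.write(row).await?;
    }
    insert.end().await?;
    Ok(())
}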
@@ -1204,62 +1189,6 @@ impl ExtensionEventRow {
}
}
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct ReplEventRow {
// AppInfoBase
app_version: String,
major: Option<i32>,
minor: Option<i32>,
patch: Option<i32>,
checksum_matched: bool,
release_channel: String,
os_name: String,
os_version: String,
// ClientEventBase
installation_id: Option<String>,
session_id: Option<String>,
is_staff: Option<bool>,
time: i64,
// ReplEventRow
kernel_language: String,
kernel_status: String,
repl_session_id: String,
}
impl ReplEventRow {
fn from_event(
event: ReplEvent,
wrapper: &EventWrapper,
body: &EventRequestBody,
first_event_at: chrono::DateTime<chrono::Utc>,
checksum_matched: bool,
) -> Self {
let semver = body.semver();
let time =
first_event_at + chrono::Duration::milliseconds(wrapper.milliseconds_since_first_event);
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
checksum_matched,
release_channel: body.release_channel.clone().unwrap_or_default(),
os_name: body.os_name.clone(),
os_version: body.os_version.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
is_staff: body.is_staff,
time: time.timestamp_millis(),
kernel_language: event.kernel_language,
kernel_status: event.kernel_status,
repl_session_id: event.repl_session_id,
}
}
}
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct EditEventRow {
// AppInfoBase

View File

@@ -9,7 +9,6 @@ use axum::{
middleware::Next,
response::IntoResponse,
};
use base64::prelude::*;
use prometheus::{exponential_buckets, register_histogram, Histogram};
pub use rpc::auth::random_token;
use scrypt::{
@@ -156,27 +155,19 @@ pub async fn create_access_token(
/// protection.
pub fn hash_access_token(token: &str) -> String {
let digest = sha2::Sha256::digest(token);
format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
format!(
"$sha256${}",
base64::encode_config(digest, base64::URL_SAFE)
)
}
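
Both sides of this hunk compute the same value; only the base64 crate generation differs (the 0.21-style engine constant vs. the older encode_config free function). The newer form in isolation, as a sketch:

use base64::prelude::*;
use sha2::{Digest, Sha256};

fn hash_access_token(token: &str) -> String {
    // URL-safe base64 of the SHA-256 digest, prefixed with the scheme tag.
    let digest = Sha256::digest(token.as_bytes());
    format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
}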
/// Encrypts the given access token with the given public key to avoid leaking it on the way
/// to the client.
pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
use rpc::auth::EncryptionFormat;
/// The encryption format to use for the access token.
///
/// Currently we're using the original encryption format to avoid
/// breaking compatibility with older clients.
///
/// Once enough clients are capable of decrypting the newer encryption
/// format we can start encrypting with `EncryptionFormat::V1`.
const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V0;
let native_app_public_key =
rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
let encrypted_access_token = native_app_public_key
.encrypt_string(access_token, ENCRYPTION_FORMAT)
.encrypt_string(access_token)
.context("failed to encrypt access token with public key")?;
Ok(encrypted_access_token)
}
@@ -400,16 +391,15 @@ mod test {
fn previous_hash_access_token(token: &str) -> Result<String> {
// Avoid slow hashing in debug mode.
let params = if cfg!(debug_assertions) {
scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
scrypt::Params::new(1, 1, 1).unwrap()
} else {
scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
scrypt::Params::new(14, 8, 1).unwrap()
};
Ok(Scrypt
.hash_password_customized(
.hash_password(
token.as_bytes(),
None,
None,
params,
&SaltString::generate(thread_rng()),
)

View File

@@ -1,3 +1,4 @@
pub mod ai;
pub mod api;
pub mod auth;
pub mod db;

View File

@@ -153,7 +153,7 @@ async fn main() -> Result<()> {
let signal = async move {
// todo(windows):
// `ctrl_close` does not work well, because tokio's signal handler always returns soon,
// but system terminates the application soon after returning CTRL+CLOSE handler.
// but system termiates the application soon after returning CTRL+CLOSE handler.
// So we should implement blocking handler to treat CTRL+CLOSE signal.
let mut ctrl_break = tokio::signal::windows::ctrl_break()
.expect("failed to listen for interrupt signal");

View File

@@ -10,9 +10,9 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
AppState, Config, Error, RateLimit, RateLimiter, Result,
AppState, Error, RateLimit, RateLimiter, Result,
};
use anyhow::{anyhow, bail, Context as _};
use anyhow::{anyhow, Context as _};
use async_tungstenite::tungstenite::{
protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
};
@@ -42,12 +42,12 @@ use futures::{
stream::FuturesUnordered,
FutureExt, SinkExt, StreamExt, TryStreamExt,
};
use http_client::IsahcHttpClient;
use http::IsahcHttpClient;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
@@ -605,40 +605,29 @@ impl Server {
))
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
complete_with_language_model(request, response, session, &app_state.config)
.await
}
}
})
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
stream_complete_with_language_model(
request,
response,
session,
&app_state.config,
)
.await
}
complete_with_language_model(
request,
response,
session,
app_state.config.openai_api_key.clone(),
app_state.config.google_ai_api_key.clone(),
app_state.config.anthropic_api_key.clone(),
)
}
})
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
count_language_model_tokens(request, response, session, &app_state.config)
.await
}
}
user_handler(move |request, response, session| {
count_tokens_with_language_model(
request,
response,
session,
app_state.config.google_ai_api_key.clone(),
)
})
})
.add_request_handler({
user_handler(move |request, response, session| {
@@ -1403,7 +1392,7 @@ pub async fn handle_websocket_request(
let socket = socket
.map_ok(to_tungstenite_message)
.err_into()
.with(|message| async move { to_axum_message(message) });
.with(|message| async move { Ok(to_axum_message(message)) });
let connection = Connection::new(Box::pin(socket));
async move {
server
@@ -4526,171 +4515,284 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
async fn complete_with_language_model(
request: proto::CompleteWithLanguageModel,
response: Response<proto::CompleteWithLanguageModel>,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
config: &Config,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
anthropic_api_key: Option<Arc<str>>,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
let api_key = config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
anthropic::complete(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?
}
_ => return Err(anyhow!("unsupported provider"))?,
};
response.send(proto::CompleteWithLanguageModelResponse {
completion: serde_json::to_string(&result)?,
})?;
Ok(())
}
async fn stream_complete_with_language_model(
request: proto::StreamCompleteWithLanguageModel,
response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
session: Session,
config: &Config,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
let api_key = config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
let mut chunks = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
api_key,
serde_json::from_str(&request.request)?,
None,
)
.await?;
while let Some(event) = chunks.next().await {
let chunk = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&chunk)?,
})?;
}
}
Some(proto::LanguageModelProvider::OpenAi) => {
let api_key = config
.openai_api_key
.as_ref()
.context("no OpenAI API key configured on the server")?;
let mut events = open_ai::stream_completion(
session.http_client.as_ref(),
open_ai::OPEN_AI_API_URL,
api_key,
serde_json::from_str(&request.request)?,
None,
)
.await?;
while let Some(event) = events.next().await {
let event = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&event)?,
})?;
}
}
Some(proto::LanguageModelProvider::Google) => {
let api_key = config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
let mut events = google_ai::stream_generate_content(
session.http_client.as_ref(),
google_ai::API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?;
while let Some(event) = events.next().await {
let event = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&event)?,
})?;
}
}
None => return Err(anyhow!("unknown provider"))?,
if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("claude") {
let api_key = anthropic_api_key
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
complete_with_anthropic(request, response, session, api_key).await?;
}
Ok(())
}
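
The reverted handler routes on model-name prefixes rather than an explicit provider enum. A compact sketch of that dispatch (the Provider enum is illustrative):

enum Provider { OpenAi, Google, Anthropic }

fn provider_for_model(model: &str) -> Option<Provider> {
    // Mirrors the starts_with checks above.
    if model.starts_with("gpt") {
        Some(Provider::OpenAi)
    } else if model.starts_with("gemini") {
        Some(Provider::Google)
    } else if model.starts_with("claude") {
        Some(Provider::Anthropic)
    } else {
        None
    }
}

Note that in the hunk itself an unrecognized model falls through to Ok(()) without ever responding; the Option return here just makes that case visible.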
async fn count_language_model_tokens(
request: proto::CountLanguageModelTokens,
response: Response<proto::CountLanguageModelTokens>,
session: Session,
config: &Config,
async fn complete_with_open_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
let mut completion_stream = open_ai::stream_completion(
session.http_client.as_ref(),
OPEN_AI_API_URL,
&api_key,
crate::ai::language_model_request_to_open_ai(request)?,
None,
)
.await
.context("open_ai::stream_completion request failed within collab")?;
session
.rate_limiter
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
.await?;
while let Some(event) = completion_stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.choices
.into_iter()
.map(|choice| proto::LanguageModelChoiceDelta {
index: choice.index,
delta: Some(proto::LanguageModelResponseMessage {
role: choice.delta.role.map(|role| match role {
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
} as i32),
content: choice.delta.content,
tool_calls: choice
.delta
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|delta| proto::ToolCallDelta {
index: delta.index as u32,
id: delta.id,
variant: match delta.function {
Some(function) => {
let name = function.name;
let arguments = function.arguments;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Google) => {
let api_key = config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?
}
_ => return Err(anyhow!("unsupported provider"))?,
};
response.send(proto::CountLanguageModelTokensResponse {
token_count: result.total_tokens as u32,
})?;
Some(proto::tool_call_delta::Variant::Function(
proto::tool_call_delta::FunctionCallDelta {
name,
arguments,
},
))
}
None => None,
},
})
.collect(),
}),
finish_reason: choice.finish_reason,
})
.collect(),
})?;
}
Ok(())
}
struct CountLanguageModelTokensRateLimit;
async fn complete_with_google_ai(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let mut stream = google_ai::stream_generate_content(
session.http_client.clone(),
google_ai::API_URL,
api_key.as_ref(),
&request.model.clone(),
crate::ai::language_model_request_to_google_ai(request)?,
)
.await
.context("google_ai::stream_generate_content request failed")?;
impl RateLimit for CountLanguageModelTokensRateLimit {
while let Some(event) = stream.next().await {
let event = event?;
response.send(proto::LanguageModelResponse {
choices: event
.candidates
.unwrap_or_default()
.into_iter()
.map(|candidate| proto::LanguageModelChoiceDelta {
index: candidate.index as u32,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(match candidate.content.role {
google_ai::Role::User => LanguageModelRole::LanguageModelUser,
google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
} as i32),
content: Some(
candidate
.content
.parts
.into_iter()
.filter_map(|part| match part {
google_ai::Part::TextPart(part) => Some(part.text),
google_ai::Part::InlineDataPart(_) => None,
})
.collect(),
),
// Tool calls are not supported for Google
tool_calls: Vec::new(),
}),
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
})
.collect(),
})?;
}
Ok(())
}
async fn complete_with_anthropic(
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
let model = anthropic::Model::from_id(&request.model)?;
let mut system_message = String::new();
let messages = request
.messages
.into_iter()
.filter_map(|message| {
match message.role() {
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
role: anthropic::Role::User,
content: message.content,
}),
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
role: anthropic::Role::Assistant,
content: message.content,
}),
// Anthropic's API breaks system instructions out as a separate field rather
// than having a system message role.
LanguageModelRole::LanguageModelSystem => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.content);
None
}
// We don't yet support tool calls for Anthropic
LanguageModelRole::LanguageModelTool => None,
}
})
.collect();
let mut stream = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
&api_key,
anthropic::Request {
model,
messages,
stream: true,
system: system_message,
max_tokens: 4092,
},
None,
)
.await?;
let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
while let Some(event) = stream.next().await {
let event = event?;
match event {
anthropic::ResponseEvent::MessageStart { message } => {
if let Some(role) = message.role {
if role == "assistant" {
current_role = proto::LanguageModelRole::LanguageModelAssistant;
} else if role == "user" {
current_role = proto::LanguageModelRole::LanguageModelUser;
}
}
}
anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
match content_block {
anthropic::ContentBlock::Text { text } => {
if !text.is_empty() {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
tool_calls: Vec::new(),
}),
finish_reason: None,
}],
})?;
}
}
}
}
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
anthropic::TextDelta::TextDelta { text } => {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: Some(proto::LanguageModelResponseMessage {
role: Some(current_role as i32),
content: Some(text),
tool_calls: Vec::new(),
}),
finish_reason: None,
}],
})?;
}
},
anthropic::ResponseEvent::MessageDelta { delta, .. } => {
if let Some(stop_reason) = delta.stop_reason {
response.send(proto::LanguageModelResponse {
choices: vec![proto::LanguageModelChoiceDelta {
index: 0,
delta: None,
finish_reason: Some(stop_reason),
}],
})?;
}
}
anthropic::ResponseEvent::ContentBlockStop { .. } => {}
anthropic::ResponseEvent::MessageStop {} => {}
anthropic::ResponseEvent::Ping {} => {}
}
}
Ok(())
}
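
Anthropic keeps system instructions in a dedicated request field rather than as a message role, which is why complete_with_anthropic folds system-role messages into one string separated by blank lines. That fold in isolation (role/content string pairs stand in for the proto messages):

fn collect_system_message(messages: &[(&str, &str)]) -> String {
    let mut system_message = String::new();
    for (role, content) in messages {
        if *role == "system" {
            if !system_message.is_empty() {
                system_message.push_str("\n\n"); // blank line between stacked instructions
            }
            system_message.push_str(content);
        }
    }
    system_message
}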
struct CountTokensWithLanguageModelRateLimit;
impl RateLimit for CountTokensWithLanguageModelRateLimit {
fn capacity() -> usize {
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
@@ -4701,10 +4803,45 @@ impl RateLimit for CountLanguageModelTokensRateLimit {
}
fn db_name() -> &'static str {
"count-language-model-tokens"
"count-tokens-with-language-model"
}
}
async fn count_tokens_with_language_model(
request: proto::CountTokensWithLanguageModel,
response: Response<proto::CountTokensWithLanguageModel>,
session: UserSession,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
if !request.model.starts_with("gemini") {
return Err(anyhow!(
"counting tokens for model: {:?} is not supported",
request.model
))?;
}
session
.rate_limiter
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
.await?;
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
let tokens_response = google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
crate::ai::count_tokens_request_to_google_ai(request)?,
)
.await?;
response.send(proto::CountTokensResponse {
token_count: tokens_response.total_tokens as u32,
})?;
Ok(())
}
struct ComputeEmbeddingsRateLimit;
impl RateLimit for ComputeEmbeddingsRateLimit {
@@ -4992,8 +5129,8 @@ async fn get_private_user_info(
Ok(())
}
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
let message = match message {
fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
match message {
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
@@ -5002,20 +5139,7 @@ fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
code: frame.code.into(),
reason: frame.reason,
})),
// We should never receive a frame while reading the message, according
// to the `tungstenite` maintainers:
//
// > It cannot occur when you read messages from the WebSocket, but it
// > can be used when you want to send the raw frames (e.g. you want to
// > send the frames to the WebSocket without composing the full message first).
// >
// > — https://github.com/snapview/tungstenite-rs/issues/268
TungsteniteMessage::Frame(_) => {
bail!("received an unexpected frame while reading the message")
}
};
Ok(message)
}
}
fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {

View File

@@ -16,7 +16,6 @@ mod notification_tests;
mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod test_server;
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};

View File

@@ -52,7 +52,7 @@ async fn test_channel_guests(
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
assert!(project_b
.update(cx_b, |project, cx| {
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
let worktree_id = project.worktrees().next().unwrap().read(cx).id();
project.create_entry((worktree_id, "b.txt"), false, cx)
})
.await

View File

@@ -76,7 +76,7 @@ async fn test_host_disconnect(
let active_call_a = cx_a.read(ActiveCall::global);
let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await;
let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
@@ -1144,7 +1144,7 @@ async fn test_share_project(
});
project_b.read_with(cx_b, |project, cx| {
let worktree = project.worktrees(cx).next().unwrap().read(cx);
let worktree = project.worktrees().next().unwrap().read(cx);
assert_eq!(
worktree.paths().map(AsRef::as_ref).collect::<Vec<_>>(),
[
@@ -1158,7 +1158,7 @@ async fn test_share_project(
project_b
.update(cx_b, |project, cx| {
let worktree = project.worktrees(cx).next().unwrap();
let worktree = project.worktrees().next().unwrap();
let entry = worktree.read(cx).entry_for_path("ignored-dir").unwrap();
project.expand_entry(worktree_id, entry.id, cx).unwrap()
})
@@ -1166,7 +1166,7 @@ async fn test_share_project(
.unwrap();
project_b.read_with(cx_b, |project, cx| {
let worktree = project.worktrees(cx).next().unwrap().read(cx);
let worktree = project.worktrees().next().unwrap().read(cx);
assert_eq!(
worktree.paths().map(AsRef::as_ref).collect::<Vec<_>>(),
[

View File

@@ -1,4 +1,3 @@
#![allow(clippy::reversed_empty_ranges)]
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use call::{ActiveCall, ParticipantLocation};
use client::ChannelId;

View File

@@ -18,9 +18,7 @@ use gpui::{
TestAppContext, UpdateGlobal,
};
use language::{
language_settings::{
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
},
language_settings::{AllLanguageSettings, Formatter, PrettierSettings},
tree_sitter_rust, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, LanguageConfig,
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
};
@@ -1377,7 +1375,7 @@ async fn test_unshare_project(
.await
.unwrap();
let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
executor.run_until_parked();
@@ -1505,8 +1503,7 @@ async fn test_project_reconnect(
let (project_a1, _) = client_a.build_local_project("/root-1/dir1", cx_a).await;
let (project_a2, _) = client_a.build_local_project("/root-2", cx_a).await;
let (project_a3, _) = client_a.build_local_project("/root-3", cx_a).await;
let worktree_a1 =
project_a1.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a1 = project_a1.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let project1_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a1.clone(), cx))
.await
@@ -2309,7 +2306,7 @@ async fn test_propagate_saves_and_fs_changes(
.await;
let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await;
let worktree_a = project_a.read_with(cx_a, |p, cx| p.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |p, _| p.worktrees().next().unwrap());
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
@@ -2319,9 +2316,9 @@ async fn test_propagate_saves_and_fs_changes(
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let project_c = client_c.build_dev_server_project(project_id, cx_c).await;
let worktree_b = project_b.read_with(cx_b, |p, cx| p.worktrees(cx).next().unwrap());
let worktree_b = project_b.read_with(cx_b, |p, _| p.worktrees().next().unwrap());
let worktree_c = project_c.read_with(cx_c, |p, cx| p.worktrees(cx).next().unwrap());
let worktree_c = project_c.read_with(cx_c, |p, _| p.worktrees().next().unwrap());
// Open and edit a buffer as both guests B and C.
let buffer_b = project_b
@@ -3023,8 +3020,8 @@ async fn test_fs_operations(
.unwrap();
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_b = project_b.read_with(cx_b, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let worktree_b = project_b.read_with(cx_b, |project, _| project.worktrees().next().unwrap());
let entry = project_b
.update(cx_b, |project, cx| {
@@ -3324,7 +3321,7 @@ async fn test_local_settings(
// As client B, join that project and observe the local settings.
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let worktree_b = project_b.read_with(cx_b, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_b = project_b.read_with(cx_b, |project, _| project.worktrees().next().unwrap());
executor.run_until_parked();
cx_b.read(|cx| {
let store = cx.global::<SettingsStore>();
@@ -3736,7 +3733,7 @@ async fn test_leaving_project(
// Client B opens a buffer.
let buffer_b1 = project_b1
.update(cx_b, |project, cx| {
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
let worktree_id = project.worktrees().next().unwrap().read(cx).id();
project.open_buffer((worktree_id, "a.txt"), cx)
})
.await
@@ -3774,7 +3771,7 @@ async fn test_leaving_project(
let buffer_b2 = project_b2
.update(cx_b, |project, cx| {
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
let worktree_id = project.worktrees().next().unwrap().read(cx).id();
project.open_buffer((worktree_id, "a.txt"), cx)
})
.await
@@ -4412,13 +4409,10 @@ async fn test_formatting_buffer(
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::External {
command: "awk".into(),
arguments: vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
}]
.into(),
)));
file.defaults.formatter = Some(Formatter::External {
command: "awk".into(),
arguments: vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
});
});
});
});
@@ -4499,7 +4493,7 @@ async fn test_prettier_formatting_buffer(
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::Auto);
file.defaults.formatter = Some(Formatter::Auto);
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
..PrettierSettings::default()
@@ -4510,9 +4504,7 @@ async fn test_prettier_formatting_buffer(
cx_b.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::LanguageServer { name: None }].into(),
)));
file.defaults.formatter = Some(Formatter::LanguageServer);
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
..PrettierSettings::default()
@@ -4628,7 +4620,7 @@ async fn test_definition(
.unwrap();
cx_b.read(|cx| {
assert_eq!(definitions_1.len(), 1);
assert_eq!(project_b.read(cx).worktrees(cx).count(), 2);
assert_eq!(project_b.read(cx).worktrees().count(), 2);
let target_buffer = definitions_1[0].target.buffer.read(cx);
assert_eq!(
target_buffer.text(),
@@ -4657,7 +4649,7 @@ async fn test_definition(
.unwrap();
cx_b.read(|cx| {
assert_eq!(definitions_2.len(), 1);
assert_eq!(project_b.read(cx).worktrees(cx).count(), 2);
assert_eq!(project_b.read(cx).worktrees().count(), 2);
let target_buffer = definitions_2[0].target.buffer.read(cx);
assert_eq!(
target_buffer.text(),
@@ -4815,7 +4807,7 @@ async fn test_references(
assert!(status.pending_work.is_empty());
assert_eq!(references.len(), 3);
assert_eq!(project.worktrees(cx).count(), 2);
assert_eq!(project.worktrees().count(), 2);
let two_buffer = references[0].buffer.read(cx);
let three_buffer = references[2].buffer.read(cx);
@@ -6200,7 +6192,7 @@ async fn test_preview_tabs(cx: &mut TestAppContext) {
let project = workspace.update(cx, |workspace, _| workspace.project().clone());
let worktree_id = project.update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
project.worktrees().next().unwrap().read(cx).id()
});
let path_1 = ProjectPath {
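Most hunks in this file track a single signature change: one side of each pair passes a context argument to Project::worktrees, the other calls it with no arguments. A minimal, self-contained sketch of the argument-free shape, using stand-in types rather than Zed's gpui-backed models:

struct Worktree { id: u64 }

struct Project { worktrees: Vec<Worktree> }

impl Project {
    // Argument-free accessor: borrow self, no `cx` parameter.
    fn worktrees(&self) -> impl Iterator<Item = &Worktree> + '_ {
        self.worktrees.iter()
    }
}

fn main() {
    let project = Project { worktrees: vec![Worktree { id: 1 }] };
    // Call sites shrink from `project.worktrees(cx).next()` to:
    let first_id = project.worktrees().next().unwrap().id;
    assert_eq!(first_id, 1);
}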

View File

@@ -301,7 +301,7 @@ impl RandomizedTest for ProjectCollaborationTest {
let is_local = project.read_with(cx, |project, _| project.is_local());
let worktree = project.read_with(cx, |project, cx| {
project
.worktrees(cx)
.worktrees()
.filter(|worktree| {
let worktree = worktree.read(cx);
worktree.is_visible()
@@ -423,7 +423,7 @@ impl RandomizedTest for ProjectCollaborationTest {
81.. => {
let worktree = project.read_with(cx, |project, cx| {
project
.worktrees(cx)
.worktrees()
.filter(|worktree| {
let worktree = worktree.read(cx);
worktree.is_visible()
@@ -1172,7 +1172,7 @@ impl RandomizedTest for ProjectCollaborationTest {
let host_worktree_snapshots =
host_project.read_with(host_cx, |host_project, cx| {
host_project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
(worktree.id(), worktree.snapshot())
@@ -1180,7 +1180,7 @@ impl RandomizedTest for ProjectCollaborationTest {
.collect::<BTreeMap<_, _>>()
});
let guest_worktree_snapshots = guest_project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
(worktree.id(), worktree.snapshot())
@@ -1538,7 +1538,7 @@ fn project_path_for_full_path(
let root_name = components.next().unwrap().as_os_str().to_str().unwrap();
let path = components.as_path().into();
let worktree_id = project.read_with(cx, |project, cx| {
project.worktrees(cx).find_map(|worktree| {
project.worktrees().find_map(|worktree| {
let worktree = worktree.read(cx);
if worktree.root_name() == root_name {
Some(worktree.id())
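The lookup above resolves a full path by treating its first component as a worktree root name and the remainder as the path inside that worktree. The same split in isolation, using only std:

use std::path::{Path, PathBuf};

// First component names the worktree root; the rest is the in-tree path.
fn split_full_path(full_path: &Path) -> Option<(String, PathBuf)> {
    let mut components = full_path.components();
    let root_name = components.next()?.as_os_str().to_str()?.to_string();
    Some((root_name, components.as_path().to_path_buf()))
}

fn main() {
    let (root, rest) = split_full_path(Path::new("zed/crates/project/src/lib.rs")).unwrap();
    assert_eq!(root, "zed");
    assert_eq!(rest, Path::new("crates/project/src/lib.rs"));
}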

View File

@@ -1,102 +0,0 @@
use crate::tests::TestServer;
use call::ActiveCall;
use fs::{FakeFs, Fs as _};
use gpui::{Context as _, TestAppContext};
use remote::SshSession;
use remote_server::HeadlessProject;
use serde_json::json;
use std::{path::Path, sync::Arc};
#[gpui::test]
async fn test_sharing_an_ssh_remote_project(
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
server_cx: &mut TestAppContext,
) {
let executor = cx_a.executor();
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
.create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
.await;
// Set up project on remote FS
let (client_ssh, server_ssh) = SshSession::fake(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor());
remote_fs
.insert_tree(
"/code",
json!({
"project1": {
"README.md": "# project 1",
"src": {
"lib.rs": "fn one() -> usize { 1 }"
}
},
"project2": {
"README.md": "# project 2",
},
}),
)
.await;
// User A connects to the remote project via SSH.
server_cx.update(HeadlessProject::init);
let _headless_project =
server_cx.new_model(|cx| HeadlessProject::new(server_ssh, remote_fs.clone(), cx));
let (project_a, worktree_id) = client_a
.build_ssh_project("/code/project1", client_ssh, cx_a)
.await;
// User A shares the remote project.
let active_call_a = cx_a.read(ActiveCall::global);
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
.unwrap();
// User B joins the project.
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let worktree_b = project_b
.update(cx_b, |project, cx| project.worktree_for_id(worktree_id, cx))
.unwrap();
executor.run_until_parked();
worktree_b.update(cx_b, |worktree, _cx| {
assert_eq!(
worktree.paths().map(Arc::as_ref).collect::<Vec<_>>(),
vec![
Path::new("README.md"),
Path::new("src"),
Path::new("src/lib.rs"),
]
);
});
// User B can open buffers in the remote project.
let buffer_b = project_b
.update(cx_b, |project, cx| {
project.open_buffer((worktree_id, "src/lib.rs"), cx)
})
.await
.unwrap();
buffer_b.update(cx_b, |buffer, cx| {
assert_eq!(buffer.text(), "fn one() -> usize { 1 }");
let ix = buffer.text().find('1').unwrap();
buffer.edit([(ix..ix + 1, "100")], None, cx);
});
project_b
.update(cx_b, |project, cx| project.save_buffer(buffer_b, cx))
.await
.unwrap();
assert_eq!(
remote_fs
.load("/code/project1/src/lib.rs".as_ref())
.await
.unwrap(),
"fn one() -> usize { 100 }"
);
}
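The deleted test drives collaboration through SshSession::fake, which hands back paired client and server halves that talk entirely in-process. A rough model of that pairing with plain std channels (the real type is async and speaks Zed's RPC protocol, so this shows only the shape):

use std::sync::mpsc::{channel, Receiver, Sender};

type End = (Sender<String>, Receiver<String>);

// Two ends wired back-to-back, so tests need no real SSH connection.
fn fake_session_pair() -> (End, End) {
    let (client_tx, server_rx) = channel();
    let (server_tx, client_rx) = channel();
    ((client_tx, client_rx), (server_tx, server_rx))
}

fn main() {
    let ((to_server, from_server), (to_client, from_client)) = fake_session_pair();
    to_server.send("open project1".into()).unwrap();
    assert_eq!(from_client.recv().unwrap(), "open project1");
    to_client.send("ok".into()).unwrap();
    assert_eq!(from_server.recv().unwrap(), "ok");
}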

View File

@@ -19,20 +19,18 @@ use fs::FakeFs;
use futures::{channel::oneshot, StreamExt as _};
use git::GitHostingProviderRegistry;
use gpui::{BackgroundExecutor, Context, Model, Task, TestAppContext, View, VisualTestContext};
use http_client::FakeHttpClient;
use http::FakeHttpClient;
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use notifications::NotificationStore;
use parking_lot::Mutex;
use project::{Project, WorktreeId};
use remote::SshSession;
use rpc::{
proto::{self, ChannelRole},
RECEIVE_TIMEOUT,
};
use semantic_version::SemanticVersion;
use serde_json::json;
use session::Session;
use settings::SettingsStore;
use std::{
cell::{Ref, RefCell, RefMut},
@@ -157,8 +155,6 @@ impl TestServer {
}
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
let fs = FakeFs::new(cx.executor());
cx.update(|cx| {
if cx.has_global::<SettingsStore>() {
panic!("Same cx used to create two test clients")
@@ -267,6 +263,7 @@ impl TestServer {
git_hosting_provider_registry
.register_hosting_provider(Arc::new(git_hosting_providers::Github));
let fs = FakeFs::new(cx.executor());
let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
let workspace_store = cx.new_model(|cx| WorkspaceStore::new(client.clone(), cx));
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
@@ -278,7 +275,6 @@ impl TestServer {
fs: fs.clone(),
build_window_options: |_, _| Default::default(),
node_runtime: FakeNodeRuntime::new(),
session: Session::test(),
});
let os_keymap = "keymaps/default-macos.json";
@@ -298,8 +294,7 @@ impl TestServer {
menu::init();
dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
language_model::LanguageModelRegistry::test(cx);
completion::init(cx);
assistant::FakeCompletionProvider::setup_test(cx);
assistant::context_store::init(&client);
});
@@ -407,7 +402,6 @@ impl TestServer {
fs: fs.clone(),
build_window_options: |_, _| Default::default(),
node_runtime: FakeNodeRuntime::new(),
session: Session::test(),
});
cx.update(|cx| {
@@ -820,30 +814,6 @@ impl TestClient {
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
pub async fn build_ssh_project(
&self,
root_path: impl AsRef<Path>,
ssh: Arc<SshSession>,
cx: &mut TestAppContext,
) -> (Model<Project>, WorktreeId) {
let project = cx.update(|cx| {
Project::ssh(
ssh,
self.client().clone(),
self.app_state.node_runtime.clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
});
let (worktree, _) = project
.update(cx, |p, cx| p.find_or_create_worktree(root_path, true, cx))
.await
.unwrap();
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
pub async fn build_test_project(&self, cx: &mut TestAppContext) -> Model<Project> {
self.fs()
.insert_tree(

View File

@@ -25,7 +25,7 @@ test-support = [
"settings/test-support",
"util/test-support",
"workspace/test-support",
"http_client/test-support",
"http/test-support",
]
[dependencies]
@@ -78,7 +78,7 @@ pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
tree-sitter-md.workspace = true
tree-sitter-markdown.workspace = true
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }

View File

@@ -1107,11 +1107,9 @@ impl Panel for ChatPanel {
}
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
settings::update_settings_file::<ChatPanelSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.dock = Some(position),
);
settings::update_settings_file::<ChatPanelSettings>(self.fs.clone(), cx, move |settings| {
settings.dock = Some(position)
});
}
fn size(&self, cx: &gpui::WindowContext) -> Pixels {
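This hunk, like the matching ones in the collab and notification panels below, is a callback-signature difference in settings::update_settings_file: one side's closure takes a second, ignored argument, the other takes the settings value alone. A stand-in sketch of the single-argument contract (the real helper persists changes to the settings file via Fs; only the closure shape is modeled here):

#[derive(Default, Debug, PartialEq)]
struct ChatPanelSettings { dock: Option<String> }

fn update_settings_file<F>(settings: &mut ChatPanelSettings, update: F)
where
    F: FnOnce(&mut ChatPanelSettings),
{
    update(settings);
}

fn main() {
    let mut settings = ChatPanelSettings::default();
    update_settings_file(&mut settings, |settings| settings.dock = Some("left".into()));
    assert_eq!(settings.dock, Some("left".into()));
}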

View File

@@ -6,7 +6,7 @@ use editor::{AnchorRangeExt, CompletionProvider, Editor, EditorElement, EditorSt
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
AsyncWindowContext, FocusableView, FontStyle, FontWeight, HighlightStyle, IntoElement, Model,
Render, Task, TextStyle, View, ViewContext, WeakView,
Render, Task, TextStyle, View, ViewContext, WeakView, WhiteSpace,
};
use language::{
language_settings::SoftWrap, Anchor, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry,
@@ -533,12 +533,14 @@ impl Render for MessageEditor {
},
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: TextSize::Small.rems(cx).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
div()
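One side of this hunk fills the remaining TextStyle fields with ..Default::default() (and adds font_fallbacks), while the other writes every field out, including a white_space field present in that version of the struct. The two initialization styles on a stand-in struct:

#[derive(Default)]
struct TextStyle {
    line_height: f32,
    background_color: Option<u32>,
    underline: Option<bool>,
    strikethrough: Option<bool>,
}

fn main() {
    // Struct-update syntax: name what differs, default the rest.
    let concise = TextStyle { line_height: 1.3, ..Default::default() };
    // Fully explicit: every field spelled out.
    let explicit = TextStyle {
        line_height: 1.3,
        background_color: None,
        underline: None,
        strikethrough: None,
    };
    assert_eq!(concise.line_height, explicit.line_height);
}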

View File

@@ -16,7 +16,7 @@ use gpui::{
EventEmitter, FocusHandle, FocusableView, FontStyle, InteractiveElement, IntoElement,
ListOffset, ListState, Model, MouseDownEvent, ParentElement, Pixels, Point, PromptLevel,
Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext,
WeakView,
WeakView, WhiteSpace,
};
use menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrev};
use project::{Fs, Project};
@@ -2190,12 +2190,14 @@ impl CollabPanel {
},
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
@@ -2807,7 +2809,7 @@ impl Panel for CollabPanel {
settings::update_settings_file::<CollaborationPanelSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.dock = Some(position),
move |settings| settings.dock = Some(position),
);
}

View File

@@ -672,7 +672,7 @@ impl Panel for NotificationPanel {
settings::update_settings_file::<NotificationPanelSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.dock = Some(position),
move |settings| settings.dock = Some(position),
);
}

View File

@@ -17,10 +17,9 @@ use gpui::{
use picker::{Picker, PickerDelegate};
use postage::{sink::Sink, stream::Stream};
use settings::Settings;
use ui::{h_flex, prelude::*, v_flex, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing};
use util::ResultExt;
use workspace::{ModalView, Workspace, WorkspaceSettings};
use workspace::{ModalView, Workspace};
use zed_actions::OpenZedUrl;
actions!(command_palette, [Toggle]);
@@ -249,13 +248,9 @@ impl PickerDelegate for CommandPaletteDelegate {
fn update_matches(
&mut self,
mut query: String,
query: String,
cx: &mut ViewContext<Picker<Self>>,
) -> gpui::Task<()> {
let settings = WorkspaceSettings::get_global(cx);
if let Some(alias) = settings.command_aliases.get(&query) {
query = alias.to_string();
}
let (mut tx, mut rx) = postage::dispatch::channel(1);
let task = cx.background_executor().spawn({
let mut commands = self.all_commands.clone();
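The removed block implemented command aliases: before fuzzy matching, the raw query is swapped for its configured expansion when one exists. The idea in isolation, with a plain HashMap standing in for the command_aliases setting:

use std::collections::HashMap;

fn resolve_alias(mut query: String, aliases: &HashMap<String, String>) -> String {
    if let Some(alias) = aliases.get(&query) {
        query = alias.clone();
    }
    query
}

fn main() {
    let aliases = HashMap::from([("w".to_string(), "workspace: save".to_string())]);
    assert_eq!(resolve_alias("w".into(), &aliases), "workspace: save");
    // Queries without an alias pass through untouched.
    assert_eq!(resolve_alias("editor: undo".into(), &aliases), "editor: undo");
}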

View File

@@ -1,45 +0,0 @@
[package]
name = "completion"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/completion.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"language_model/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -1,309 +0,0 @@
use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest, LanguageModelTool,
};
use smol::{future::FutureExt, lock::{Semaphore, SemaphoreGuardArc}};
use std::{future, pin::Pin, sync::Arc, task::Poll};
use ui::Context;
pub fn init(cx: &mut AppContext) {
let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
impl Global for GlobalLanguageModelCompletionProvider {}
pub struct LanguageModelCompletionProvider {
active_provider: Option<Arc<dyn LanguageModelProvider>>,
active_model: Option<Arc<dyn LanguageModel>>,
request_limiter: Arc<Semaphore>,
}
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
pub struct LanguageModelCompletionResponse {
inner: BoxStream<'static, Result<String>>,
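// Holds the rate-limiter permit for as long as this response stream is alive; dropping the response releases it.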
_lock: SemaphoreGuardArc,
}
impl futures::Stream for LanguageModelCompletionResponse {
type Item = Result<String>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}
impl LanguageModelCompletionProvider {
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalLanguageModelCompletionProvider>()
.0
.clone()
}
pub fn read_global(cx: &AppContext) -> &Self {
cx.global::<GlobalLanguageModelCompletionProvider>()
.0
.read(cx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut AppContext) {
let provider = cx.new_model(|cx| {
let mut this = Self::new(cx);
let available_model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.first()
.unwrap()
.clone();
this.set_active_model(available_model, cx);
this
});
cx.set_global(GlobalLanguageModelCompletionProvider(provider));
}
pub fn new(cx: &mut ModelContext<Self>) -> Self {
cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
cx.notify();
})
.detach();
Self {
active_provider: None,
active_model: None,
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
self.active_provider.clone()
}
pub fn set_active_provider(
&mut self,
provider_id: LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
self.active_model = None;
cx.notify();
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.clone()
}
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
if self.active_model.as_ref().map_or(false, |m| {
m.id() == model.id() && m.provider_id() == model.provider_id()
}) {
return;
}
self.active_provider =
LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
self.active_model = Some(model.clone());
if let Some(provider) = self.active_provider.as_ref() {
provider.load_model(model, cx);
}
cx.notify();
}
pub fn is_authenticated(&self, cx: &AppContext) -> bool {
self.active_provider
.as_ref()
.map_or(false, |provider| provider.is_authenticated(cx))
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.active_provider
.as_ref()
.map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.active_provider
.as_ref()
.map_or(Task::ready(Ok(())), |provider| {
provider.reset_credentials(cx)
})
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = self.active_model() {
model.count_tokens(request, cx)
} else {
future::ready(Err(anyhow!("no active model"))).boxed()
}
}
pub fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<LanguageModelCompletionResponse>> {
if let Some(language_model) = self.active_model() {
let rate_limiter = self.request_limiter.clone();
cx.spawn(|cx| async move {
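// Waiting for a permit here is what caps concurrency at MAX_CONCURRENT_COMPLETION_REQUESTS; the guard is moved into the response below and released when the response is dropped.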
let lock = rate_limiter.acquire_arc().await;
let response = language_model.stream_completion(request, &cx).await?;
Ok(LanguageModelCompletionResponse {
inner: response,
_lock: lock,
})
})
} else {
Task::ready(Err(anyhow!("No active model set")))
}
}
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
let response = self.stream_completion(request, cx);
cx.foreground_executor().spawn(async move {
let mut chunks = response.await?;
let mut completion = String::new();
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
completion.push_str(&chunk);
}
Ok(completion)
})
}
pub fn use_tool<T: LanguageModelTool>(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<T>> {
if let Some(language_model) = self.active_model() {
cx.spawn(|cx| async move {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request =
language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
let response = request.await?;
Ok(serde_json::from_value(response)?)
})
} else {
Task::ready(Err(anyhow!("No active model set")))
}
}
pub fn active_model_telemetry_id(&self) -> Option<String> {
self.active_model.as_ref().map(|m| m.telemetry_id())
}
}
#[cfg(test)]
mod tests {
use futures::StreamExt;
use gpui::AppContext;
use settings::SettingsStore;
use ui::Context;
use crate::{
LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
};
use language_model::LanguageModelRegistry;
#[gpui::test]
fn test_rate_limiting(cx: &mut AppContext) {
SettingsStore::test(cx);
let fake_provider = LanguageModelRegistry::test(cx);
let model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.first()
.cloned()
.unwrap();
let provider = cx.new_model(|cx| {
let mut provider = LanguageModelCompletionProvider::new(cx);
provider.set_active_model(model.clone(), cx);
provider
});
let fake_model = fake_provider.test_model();
// Enqueue some requests
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
let response = provider.read(cx).stream_completion(
LanguageModelRequest {
temperature: i as f32 / 10.0,
..Default::default()
},
cx,
);
cx.background_executor()
.spawn(async move {
let mut stream = response.await.unwrap();
while let Some(message) = stream.next().await {
message.unwrap();
}
})
.detach();
}
cx.background_executor().run_until_parked();
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Get the first completion request that is in flight and mark it as completed.
let completion = fake_model.pending_completions().into_iter().next().unwrap();
fake_model.finish_completion(&completion);
// Ensure that the number of in-flight completion requests is reduced.
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
cx.background_executor().run_until_parked();
// Ensure that another completion request was allowed to acquire the lock.
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Mark all in-flight completion requests as finished.
for request in fake_model.pending_completions() {
fake_model.finish_completion(&request);
}
assert_eq!(fake_model.completion_count(), 0);
// Wait until the background tasks acquire the lock again.
cx.background_executor().run_until_parked();
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
// Finish all remaining completion requests.
for request in fake_model.pending_completions() {
fake_model.finish_completion(&request);
}
cx.background_executor().run_until_parked();
assert_eq!(fake_model.completion_count(), 0);
}
}
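The pattern this deleted module used is compact enough to distill: a semaphore caps in-flight completions, and the guard travels inside the response so a permit frees up exactly when a response is dropped. A runnable reduction using smol, as the module itself did (the Response type here is a stand-in):

use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::sync::Arc;

const MAX_CONCURRENT: usize = 4;

struct Response {
    _body: String,
    // The permit lives exactly as long as the response.
    _permit: SemaphoreGuardArc,
}

fn main() {
    smol::block_on(async {
        let limiter = Arc::new(Semaphore::new(MAX_CONCURRENT));
        let mut in_flight = Vec::new();
        for i in 0..MAX_CONCURRENT {
            let permit = limiter.acquire_arc().await;
            in_flight.push(Response { _body: format!("reply {i}"), _permit: permit });
        }
        // All permits are held, so a fifth request would block...
        assert!(limiter.try_acquire().is_none());
        // ...until some response is dropped, releasing its permit.
        in_flight.pop();
        assert!(limiter.try_acquire().is_some());
    });
}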

View File

@@ -32,7 +32,7 @@ command_palette_hooks.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
language.workspace = true
lsp.workspace = true
menu.workspace = true
@@ -65,4 +65,4 @@ rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

Some files were not shown because too many files have changed in this diff.