Compare commits


6 Commits

Author SHA1 Message Date
Peter Tripp a0d4c82c0c Another bad path. 2024-10-13 22:11:50 -04:00
Peter Tripp 8e6fac76df Fix two asset loads. 2024-10-13 21:57:44 -04:00
Peter Tripp a027217629 Merge branch 'main' into ibm_plex 2024-10-11 10:34:22 -04:00
Peter Tripp 6183b02db0 fonts.md 2024-08-31 01:04:58 -04:00
Peter Tripp ced0d608ca Newline 2024-08-31 01:00:13 -04:00
Peter Tripp 4a9e15b1b1 Switch to upstream IBM Plex fonts.
- Update [@ibm/plex-sans@1.0.0](https://github.com/IBM/plex/releases/tag/%40ibm/plex-sans%401.0.0)
- Update [@ibm/plex-mono@1.0.0](https://github.com/IBM/plex/releases/tag/%40ibm/plex-mono%401.0.0)

```
ca403c56931baef307d20ba64b69acb71abcad61f75e66414661d57484b690ec  ibm-plex-mono/IBMPlexMono-Bold.ttf
0e45a5a540992163229d2a29662553f313fab391757ca2ab3dc8f4e0d9be0979  ibm-plex-mono/IBMPlexMono-BoldItalic.ttf
8ebe04c8c6cc82f0be19896ddc61d9935cdd0f027b0173c1945b8d247d7dfc2a  ibm-plex-mono/IBMPlexMono-Italic.ttf
fe11304a5fe956d5744e9b6a246cc83d90425245e75a62230044966ca96a7f50  ibm-plex-mono/IBMPlexMono-Regular.ttf
9e6c74a889a700d707613d24548fe4ffa6bc59559a0689d2cf9e133bdcdafb2f  ibm-plex-sans/IBMPlexSans-Bold.ttf
0e3142ba2ef31fe5c02f0c6c36424f251609cd6b73880076e21c2e81931ba2b9  ibm-plex-sans/IBMPlexSans-BoldItalic.ttf
a9c6ef9942c49e49d11e11a6dacc0b3a087978757e9b22a06b8ac22a6400fb15  ibm-plex-sans/IBMPlexSans-Italic.ttf
975dcda37d80f038dcd143c22e33ca2d97a0cc5a929aace1c749153b0fe1afa5  ibm-plex-sans/IBMPlexSans-Regular.ttf
```
2024-08-31 00:58:48 -04:00
364 changed files with 9376 additions and 12476 deletions
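For reference, a minimal sketch for verifying the SHA-256 checksums listed in the commit message above (assumes the `sha2` crate; hypothetical helper, not part of the change; paths are relative to the fonts directory as listed):

```rust
// Sketch only: check the font files against the checksums above.
// Assumes the `sha2` crate (v0.10); not part of the original commit.
use sha2::{Digest, Sha256};
use std::fs;

fn sha256_hex(path: &str) -> std::io::Result<String> {
    // `Sha256::digest` hashes the whole byte slice in one call.
    Ok(format!("{:x}", Sha256::digest(fs::read(path)?)))
}

fn main() -> std::io::Result<()> {
    // Two of the eight files; the rest follow the same pattern.
    let expected = [
        ("ibm-plex-mono/IBMPlexMono-Regular.ttf",
         "fe11304a5fe956d5744e9b6a246cc83d90425245e75a62230044966ca96a7f50"),
        ("ibm-plex-sans/IBMPlexSans-Regular.ttf",
         "975dcda37d80f038dcd143c22e33ca2d97a0cc5a929aace1c749153b0fe1afa5"),
    ];
    for (path, want) in expected {
        let got = sha256_hex(path)?;
        println!("{path}: {}", if got == want { "OK" } else { "MISMATCH" });
    }
    Ok(())
}
```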

View File

@@ -7,3 +7,9 @@ runs:
- name: cargo fmt
shell: bash -euxo pipefail {0}
run: cargo fmt --all -- --check
- name: Find modified migrations
shell: bash -euxo pipefail {0}
run: |
export SQUAWK_GITHUB_TOKEN=${{ github.token }}
. ./script/squawk

View File

@@ -2,4 +2,14 @@ Closes #ISSUE
Release Notes:
- N/A *or* Added/Fixed/Improved ...
- Added/Fixed/Improved ...
Optionally, include screenshots / media showcasing your addition that can be included in the release notes.
### Or...
Closes #ISSUE
Release Notes:
- N/A

View File

@@ -14,7 +14,6 @@ on:
- "**"
paths-ignore:
- "docs/**"
- ".github/workflows/community_*"
concurrency:
# Allow only one workflow per any non-`main` branch.
@@ -27,10 +26,9 @@ env:
RUST_BACKTRACE: 1
jobs:
migration_checks:
name: Check Postgres and Protobuf migrations, mergeability
if: github.repository_owner == 'zed-industries'
style:
timeout-minutes: 60
name: Check formatting and spelling
runs-on:
- self-hosted
- test
@@ -39,16 +37,25 @@ jobs:
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
clean: false
fetch-depth: 0 # fetch full history
fetch-depth: 0
- name: Remove untracked files
run: git clean -df
- name: Find modified migrations
shell: bash -euxo pipefail {0}
run: |
export SQUAWK_GITHUB_TOKEN=${{ github.token }}
. ./script/squawk
- name: Check spelling
run: script/check-spelling
- name: Run style checks
uses: ./.github/actions/check_style
- name: Check unused dependencies
uses: bnjbvr/cargo-machete@main
- name: Check licenses are present
run: script/check-licenses
- name: Check license generation
run: script/generate-licenses /tmp/zed_licenses_output
- name: Ensure fresh merge
shell: bash -euxo pipefail {0}
@@ -70,24 +77,6 @@ jobs:
input: "crates/proto/proto/"
against: "https://github.com/${GITHUB_REPOSITORY}.git#branch=${BUF_BASE_BRANCH},subdir=crates/proto/proto/"
style:
timeout-minutes: 60
name: Check formatting and spelling
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-8vcpu-ubuntu-2204
steps:
- name: Checkout repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
- name: Run style checks
uses: ./.github/actions/check_style
- name: Check for typos
uses: crate-ci/typos@v1.24.6
with:
config: ./typos.toml
macos_tests:
timeout-minutes: 60
name: (macOS) Run Clippy and tests
@@ -103,25 +92,14 @@ jobs:
- name: cargo clippy
run: ./script/clippy
- name: Check unused dependencies
uses: bnjbvr/cargo-machete@main
- name: Check licenses
run: |
script/check-licenses
script/generate-licenses /tmp/zed_licenses_output
- name: Run tests
uses: ./.github/actions/run_tests
- name: Build collab
run: RUSTFLAGS="-D warnings" cargo build -p collab
run: cargo build -p collab
- name: Build other binaries and features
run: |
RUSTFLAGS="-D warnings" cargo build --workspace --bins --all-features
cargo check -p gpui --features "macos-blade"
RUSTFLAGS="-D warnings" cargo build -p remote_server
run: cargo build --workspace --bins --all-features; cargo check -p gpui --features "macos-blade"
linux_tests:
timeout-minutes: 60
@@ -138,7 +116,7 @@ jobs:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@82a92a6e8fbeee089604da2575dc567ae9ddeaab # v2
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
@@ -153,33 +131,7 @@ jobs:
uses: ./.github/actions/run_tests
- name: Build Zed
run: RUSTFLAGS="-D warnings" cargo build -p zed
build_remote_server:
timeout-minutes: 60
name: (Linux) Build Remote Server
runs-on:
- buildjet-16vcpu-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Checkout repo
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
with:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@82a92a6e8fbeee089604da2575dc567ae9ddeaab # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
- name: Install Clang & Mold
run: ./script/remote-server && ./script/install-mold 2.34.0
- name: Build Remote Server
run: RUSTFLAGS="-D warnings" cargo build -p remote_server
run: cargo build -p zed
# todo(windows): Actually run the tests
windows_tests:
@@ -193,7 +145,7 @@ jobs:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@82a92a6e8fbeee089604da2575dc567ae9ddeaab # v2
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "github"
@@ -203,7 +155,7 @@ jobs:
run: cargo xtask clippy
- name: Build Zed
run: $env:RUSTFLAGS="-D warnings"; cargo build
run: cargo build -p zed
bundle-mac:
timeout-minutes: 60
@@ -267,20 +219,20 @@ jobs:
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
- name: Upload app bundle (universal) to workflow run if main branch or specific label
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
if: ${{ github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}.dmg
path: target/release/Zed.dmg
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
if: ${{ github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
if: ${{ github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
@@ -331,7 +283,7 @@ jobs:
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
if: ${{ github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
@@ -378,7 +330,7 @@ jobs:
run: script/bundle-linux
- name: Upload Linux bundle to workflow run if main branch or specific label
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
if: ${{ github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
with:
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz

View File

@@ -17,7 +17,7 @@ jobs:
We're working to clean up our issue tracker by closing older issues that might not be relevant anymore. Are you able to reproduce this issue in the latest version of Zed? If so, please let us know by commenting on this issue and we will keep it open; otherwise, we'll close it in 7 days. Feel free to open a new issue if you're seeing this message after the issue has been closed.
Thanks for your help!
close-issue-message: "This issue was closed due to inactivity. If you're still experiencing this problem, please open a new issue with a link to this issue."
close-issue-message: "This issue was closed due to inactivity. If you're still experiencing this problem, feel free to ping a Zed team member to reopen this issue or open a new one."
# We will increase `days-before-stale` to 365 on or after Jan 24th,
# 2024. This date marks one year since migrating issues from
# 'community' to 'zed' repository. The migration added activity to all

View File

@@ -8,7 +8,6 @@ on:
jobs:
deploy-docs:
name: Deploy Docs
if: github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest
steps:

View File

@@ -11,7 +11,6 @@ on:
jobs:
check_formatting:
name: "Check formatting"
if: github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest
steps:
@@ -30,8 +29,5 @@ jobs:
false
}
- name: Check for Typos with Typos-CLI
uses: crate-ci/typos@v1.24.6
with:
config: ./typos.toml
files: ./docs/
- name: Check spelling
run: script/check-spelling docs/

View File

@@ -21,7 +21,7 @@ jobs:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@82a92a6e8fbeee089604da2575dc567ae9ddeaab # v2
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "github"

View File

@@ -1,5 +1,3 @@
name: Release Actions
on:
release:
types: [published]

View File

@@ -1,5 +1,3 @@
name: Update All Top Ranking Issues
on:
schedule:
- cron: "0 */12 * * *"
@@ -12,14 +10,14 @@ jobs:
steps:
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
- name: Set up uv
uses: astral-sh/setup-uv@f3bcaebff5eace81a1c062af9f9011aae482ca9d # v3
uses: astral-sh/setup-uv@v3
with:
version: "latest"
enable-cache: true
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
- name: Install Python 3.13
run: uv python install 3.13
- name: Install Python 3.12
run: uv python install 3.12
- name: Install dependencies
run: uv sync --project script/update_top_ranking_issues -p 3.13
run: uv sync --project script/update_top_ranking_issues -p 3.12
- name: Run script
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393

View File

@@ -1,5 +1,3 @@
name: Update Weekly Top Ranking Issues
on:
schedule:
- cron: "0 15 * * *"
@@ -12,14 +10,14 @@ jobs:
steps:
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
- name: Set up uv
uses: astral-sh/setup-uv@f3bcaebff5eace81a1c062af9f9011aae482ca9d # v3
uses: astral-sh/setup-uv@v3
with:
version: "latest"
enable-cache: true
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
- name: Install Python 3.13
run: uv python install 3.13
- name: Install Python 3.12
run: uv python install 3.12
- name: Install dependencies
run: uv sync --project script/update_top_ranking_issues -p 3.13
run: uv sync --project script/update_top_ranking_issues -p 3.12
- name: Run script
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7

Cargo.lock (generated, 666 lines changed)

File diff suppressed because it is too large

View File

@@ -52,6 +52,7 @@ members = [
"crates/indexed_docs",
"crates/inline_completion_button",
"crates/install_cli",
"crates/isahc_http_client",
"crates/journal",
"crates/language",
"crates/language_model",
@@ -87,7 +88,6 @@ members = [
"crates/remote",
"crates/remote_server",
"crates/repl",
"crates/reqwest_client",
"crates/rich_text",
"crates/rope",
"crates/rpc",
@@ -122,7 +122,6 @@ members = [
"crates/ui",
"crates/ui_input",
"crates/ui_macros",
"crates/reqwest_client",
"crates/util",
"crates/vcs_menu",
"crates/vim",
@@ -156,12 +155,15 @@ members = [
"extensions/proto",
"extensions/purescript",
"extensions/ruff",
"extensions/ruby",
"extensions/slash-commands-example",
"extensions/snippets",
"extensions/svelte",
"extensions/terraform",
"extensions/test-extension",
"extensions/toml",
"extensions/uiua",
"extensions/vue",
"extensions/zig",
#
@@ -217,7 +219,7 @@ git = { path = "crates/git" }
git_hosting_providers = { path = "crates/git_hosting_providers" }
go_to_line = { path = "crates/go_to_line" }
google_ai = { path = "crates/google_ai" }
gpui = { path = "crates/gpui", default-features = false, features = ["http_client"]}
gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
headless = { path = "crates/headless" }
html_to_markdown = { path = "crates/html_to_markdown" }
@@ -226,6 +228,7 @@ image_viewer = { path = "crates/image_viewer" }
indexed_docs = { path = "crates/indexed_docs" }
inline_completion_button = { path = "crates/inline_completion_button" }
install_cli = { path = "crates/install_cli" }
isahc_http_client = { path = "crates/isahc_http_client" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_model = { path = "crates/language_model" }
@@ -262,7 +265,6 @@ release_channel = { path = "crates/release_channel" }
remote = { path = "crates/remote" }
remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
reqwest_client = { path = "crates/reqwest_client" }
rich_text = { path = "crates/rich_text" }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
@@ -324,7 +326,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
async-recursion = "1.0.0"
async-tar = "0.5.0"
async-trait = "0.1"
async-tungstenite = "0.24"
async-tungstenite = "0.23"
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
base64 = "0.22"
@@ -333,7 +335,6 @@ blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb
blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
blake3 = "1.5.3"
bytes = "1.0"
cargo_metadata = "0.18"
cargo_toml = "0.20"
chrono = { version = "0.4", features = ["serde"] }
@@ -347,7 +348,6 @@ ctor = "0.2.6"
dashmap = "6.0"
derive_more = "0.99.17"
dirs = "4.0"
ec4rs = "1.1"
emojis = "0.6.1"
env_logger = "0.11"
exec = "0.3.1"
@@ -366,6 +366,10 @@ ignore = "0.4.22"
image = "0.25.1"
indexmap = { version = "1.6.2", features = ["serde"] }
indoc = "2"
# We explicitly disable http2 support in isahc.
isahc = { version = "1.7.2", default-features = false, features = [
"text-decoding",
] }
itertools = "0.13.0"
jsonwebtoken = "9.3"
libc = "0.2"
@@ -390,14 +394,6 @@ pulldown-cmark = { version = "0.12.0", default-features = false }
rand = "0.8.5"
regex = "1.5"
repair_json = "0.1.0"
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "fd110f6998da16bbca97b6dddda9be7827c50e29", default-features = false, features = [
"charset",
"http2",
"macos-system-configuration",
"rustls-tls-native-roots",
"socks",
"stream",
] }
rsa = "0.9.6"
runtimelib = { version = "0.15", default-features = false, features = [
"async-dispatcher-runtime",
@@ -442,7 +438,7 @@ time = { version = "0.3", features = [
] }
tiny_http = "0.8"
toml = "0.8"
tokio = { version = "1" }
tokio = { version = "1", features = ["full"] }
tower-http = "0.4.4"
tree-sitter = { version = "0.23", features = ["wasm"] }
tree-sitter-bash = "0.23"
@@ -455,7 +451,6 @@ tree-sitter-go = "0.23"
tree-sitter-go-mod = { git = "https://github.com/zed-industries/tree-sitter-go-mod", rev = "a9aea5e358cde4d0f8ff20b7bc4fa311e359c7ca", package = "tree-sitter-gomod" }
tree-sitter-gowork = { git = "https://github.com/zed-industries/tree-sitter-go-work", rev = "acb0617bf7f4fda02c6217676cc64acb89536dc7" }
tree-sitter-heex = { git = "https://github.com/zed-industries/tree-sitter-heex", rev = "1dd45142fbb05562e35b2040c6129c9bca346592" }
tree-sitter-diff = "0.1.0"
tree-sitter-html = "0.20"
tree-sitter-jsdoc = "0.23"
tree-sitter-json = "0.23"
@@ -483,11 +478,9 @@ wasmtime = { version = "24", default-features = false, features = [
wasmtime-wasi = "24"
which = "6.0.0"
wit-component = "0.201"
zstd = "0.11"
[workspace.dependencies.async-stripe]
git = "https://github.com/zed-industries/async-stripe"
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
version = "0.39"
default-features = false
features = [
"runtime-tokio-hyper-rustls",

View File

@@ -1,2 +0,0 @@
[build]
dockerfile = "Dockerfile-cross"

View File

@@ -13,9 +13,30 @@ ARG GITHUB_SHA
ENV GITHUB_SHA=$GITHUB_SHA
# At some point in the past 3 weeks, additional dependencies on `xkbcommon` and
# `xkbcommon-x11` were introduced into collab.
#
# A `git bisect` points to this commit as being the culprit: `b8e6098f60e5dabe98fe8281f993858dacc04a55`.
#
# Now when we try to build collab for the Docker image, it fails with the following
# error:
#
# ```
# 985.3 = note: /usr/bin/ld: cannot find -lxkbcommon: No such file or directory
# 985.3 /usr/bin/ld: cannot find -lxkbcommon-x11: No such file or directory
# 985.3 collect2: error: ld returned 1 exit status
# ```
#
# The last successful deploys were at:
# - Staging: `4f408ec65a3867278322a189b4eb20f1ab51f508`
# - Production: `fc4c533d0a8c489e5636a4249d2b52a80039fbd7`
#
# Also add `cmake`, since we need it to build `wasmtime`.
#
# Installing these as a temporary workaround, but I think ideally we'd want to figure
# out what caused them to be included in the first place.
RUN apt-get update; \
apt-get install -y --no-install-recommends cmake
apt-get install -y --no-install-recommends libxkbcommon-dev libxkbcommon-x11-dev cmake
RUN --mount=type=cache,target=./script/node_modules \
--mount=type=cache,target=/usr/local/cargo/registry \

View File

@@ -1,17 +0,0 @@
# syntax=docker/dockerfile:1
ARG CROSS_BASE_IMAGE
FROM ${CROSS_BASE_IMAGE}
WORKDIR /app
ARG TZ=Etc/UTC \
LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive
ENV CARGO_TERM_COLOR=always
COPY script/install-mold script/
RUN ./script/install-mold "2.34.0"
COPY script/remote-server script/
RUN ./script/remote-server
COPY . .

View File

@@ -1,16 +0,0 @@
.git
.github
**/.gitignore
**/.gitkeep
.gitattributes
.mailmap
**/target
zed.xcworkspace
.DS_Store
compose.yml
plugins/bin
script/node_modules
styles/node_modules
crates/collab/static/styles.css
vendor/bin
assets/themes/

8 binary files not shown.

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-diff"><path d="M12 3v14"/><path d="M5 10h14"/><path d="M5 21h14"/></svg>


View File

@@ -20,7 +20,6 @@
"bashrc": "terminal",
"bmp": "image",
"c": "c",
"c++": "cpp",
"cc": "cpp",
"cjs": "javascript",
"coffee": "coffeescript",
@@ -28,7 +27,6 @@
"cpp": "cpp",
"css": "css",
"csv": "storage",
"cxx": "cpp",
"cts": "typescript",
"dart": "dart",
"dat": "storage",
@@ -68,13 +66,11 @@
"heex": "elixir",
"heic": "image",
"heif": "image",
"hh": "cpp",
"hpp": "cpp",
"hrl": "erlang",
"hs": "haskell",
"htm": "template",
"html": "template",
"hxx": "cpp",
"ib": "storage",
"ico": "image",
"ini": "settings",
@@ -128,7 +124,6 @@
"php": "php",
"plist": "template",
"png": "image",
"postcss": "css",
"ppt": "document",
"pptx": "document",
"prettierignore": "prettier",

View File

@@ -395,7 +395,6 @@
// Change the default action on `menu::Confirm` by setting the parameter
// "alt-cmd-o": ["projects::OpenRecent", {"create_new_window": true }],
"alt-cmd-o": "projects::OpenRecent",
"ctrl-cmd-o": "projects::OpenRemote",
"alt-cmd-b": "branches::OpenRecent",
"ctrl-~": "workspace::NewTerminal",
"cmd-s": "workspace::Save",

View File

@@ -34,7 +34,7 @@
"cmd-]": "pane::GoForward",
"alt-f7": "editor::FindAllReferences",
"cmd-alt-f7": "editor::FindAllReferences",
"cmd-b": "editor::GoToDefinition", // Conflicts with workspace::ToggleLeftDock
"cmd-b": "editor::GoToDefinition",
"cmd-alt-b": "editor::GoToDefinitionSplit",
"cmd-shift-b": "editor::GoToTypeDefinition",
"cmd-alt-shift-b": "editor::GoToTypeDefinitionSplit",
@@ -64,8 +64,7 @@
"cmd-shift-o": "file_finder::Toggle",
"cmd-shift-a": "command_palette::Toggle",
"shift shift": "command_palette::Toggle",
"cmd-alt-o": "project_symbols::Toggle", // JetBrains: Go to Symbol
"cmd-o": "project_symbols::Toggle", // JetBrains: Go to Class
"cmd-alt-o": "project_symbols::Toggle",
"cmd-1": "workspace::ToggleLeftDock",
"cmd-6": "diagnostics::Deploy"
}

View File

@@ -128,10 +128,6 @@
"shift-m": "vim::WindowMiddle",
"shift-l": "vim::WindowBottom",
// z commands
"z enter": ["workspace::SendKeystrokes", "z t ^"],
"z -": ["workspace::SendKeystrokes", "z b ^"],
"z ^": ["workspace::SendKeystrokes", "shift-h k z b ^"],
"z +": ["workspace::SendKeystrokes", "shift-l j z t ^"],
"z t": "editor::ScrollCursorTop",
"z z": "editor::ScrollCursorCenter",
"z .": ["workspace::SendKeystrokes", "z z ^"],
@@ -256,7 +252,6 @@
"@": ["vim::PushOperator", "ReplayRegister"],
"ctrl-pagedown": "pane::ActivateNextItem",
"ctrl-pageup": "pane::ActivatePrevItem",
"insert": "vim::InsertBefore",
// tree-sitter related commands
"[ x": "editor::SelectLargerSyntaxNode",
"] x": "editor::SelectSmallerSyntaxNode",
@@ -339,8 +334,7 @@
"ctrl-t": "vim::Indent",
"ctrl-d": "vim::Outdent",
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
"ctrl-r": ["vim::PushOperator", "Register"],
"insert": "vim::ToggleReplace"
"ctrl-r": ["vim::PushOperator", "Register"]
}
},
{
@@ -359,8 +353,7 @@
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
"backspace": "vim::UndoReplace",
"tab": "vim::Tab",
"enter": "vim::Enter",
"insert": "vim::InsertBefore"
"enter": "vim::Enter"
}
},
{

View File

@@ -1,33 +1,85 @@
<task_description>
The user of a code editor wants to make a change to their codebase.
You must describe the change using the following XML structure:
# Code Change Workflow
- <patch> - A group of related code changes.
Child tags:
- <title> (required) - A high-level description of the changes. This should be as short
as possible, possibly using common abbreviations.
- <edit> (1 or more) - An edit to make at a particular range within a file.
Includes the following child tags:
- <path> (required) - The path to the file that will be changed.
- <description> (optional) - An arbitrarily-long comment that describes the purpose
of this edit.
- <old_text> (optional) - An excerpt from the file's current contents that uniquely
identifies a range within the file where the edit should occur. If this tag is not
specified, then the entire file will be used as the range.
- <new_text> (required) - The new text to insert into the file.
- <operation> (required) - The type of change that should occur at the given range
of the file. Must be one of the following values:
- `update`: Replaces the entire range with the new text.
- `insert_before`: Inserts the new text before the range.
- `insert_after`: Inserts new text after the range.
- `create`: Creates a new file with the given path and the new text.
- `delete`: Deletes the specified range from the file.
Your task is to guide the user through code changes using a series of steps. Each step should describe a high-level change, which can consist of multiple edits to distinct locations in the codebase.
## Output Example
Provide output as XML, with the following format:
<step>
Update the Person struct to store an age
```rust
struct Person {
// existing fields...
age: u8,
height: f32,
// existing fields...
}
impl Person {
fn age(&self) -> u8 {
self.age
}
}
```
<edit>
<path>src/person.rs</path>
<operation>insert_before</operation>
<search>height: f32,</search>
<description>Add the age field</description>
</edit>
<edit>
<path>src/person.rs</path>
<operation>insert_after</operation>
<search>impl Person {</search>
<description>Add the age getter</description>
</edit>
</step>
## Output Format
First, each `<step>` must contain a written description of the change that should be made. The description should begin with a high-level overview, and can contain markdown code blocks as well. The description should be self-contained and actionable.
After the description, each `<step>` must contain one or more `<edit>` tags, each of which refer to a specific range in a source file. Each `<edit>` tag must contain the following child tags:
### `<path>` (required)
This tag contains the path to the file that will be changed. It can be an existing path, or a path that should be created.
### `<search>` (optional)
This tag contains a search string to locate in the source file, e.g. `pub fn baz() {`. If not provided, the new content will be inserted at the top of the file. Make sure to produce a string that exists in the source file and that isn't ambiguous. When there's ambiguity, add more lines to the search to eliminate it.
### `<description>` (required)
This tag contains a single-line description of the edit that should be made at the given location.
### `<operation>` (required)
This tag indicates what type of change should be made, relative to the given location. It can be one of the following:
- `update`: Rewrites the specified string entirely based on the given description.
- `create`: Creates a new file with the given path based on the provided description.
- `insert_before`: Inserts new text based on the given description before the specified search string.
- `insert_after`: Inserts new text based on the given description after the specified search string.
- `delete`: Deletes the specified string from the containing file.
<guidelines>
- Never provide multiple edits whose ranges intersect each other. Instead, merge them into one edit.
- Prefer multiple edits to smaller, disjoint ranges, rather than one edit to a larger range.
- There's no need to escape angle brackets within XML tags.
- There's no need to describe *what* to do, just *where* to do it.
- Only reference locations that actually exist (unless you're creating a file).
- If creating a file, assume any subsequent updates are included at the time of creation.
- Don't create and then update a file. Always create new files in one shot.
- Prefer multiple edits to smaller regions, as opposed to one big edit to a larger region.
- Don't produce edits that intersect each other. In that case, merge them into a bigger edit.
- Never nest an edit with another edit. Never include CDATA. All edits are leaf nodes.
- Descriptions are required for all edits except delete.
- When generating multiple edits, ensure the descriptions are specific to each individual operation.
- Avoid referring to the search string in the description. Focus on the change to be made, not the location where it's made. That's implicit with the `search` string you provide.
- Don't generate multiple edits at the same location. Instead, combine them together in a single edit with a succinct combined description.
- Always ensure imports are added if you're referencing symbols that are not in scope.
</guidelines>
@@ -72,137 +124,189 @@ Update all shapes to store their origin as an (x, y) tuple and implement Display
<message role="assistant">
We'll need to update both the rectangle and circle modules.
<patch>
<title>Add origins and display impls to shapes</title>
<edit>
<path>src/shapes/rectangle.rs</path>
<description>Add the origin field to Rectangle struct</description>
<operation>insert_after</operation>
<old_text>
pub struct Rectangle {
</old_text>
<new_text>
origin: (f64, f64),
</new_text>
</edit>
<step>
Add origin fields to both shape types.
<edit>
<path>src/shapes/rectangle.rs</path>
<description>Update the Rectangle's new function to take an origin parameter</description>
<operation>update</operation>
<old_text>
fn new(width: f64, height: f64) -> Self {
Rectangle { width, height }
}
</old_text>
<new_text>
fn new(origin: (f64, f64), width: f64, height: f64) -> Self {
Rectangle { origin, width, height }
}
</new_text>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<description>Add the origin field to Circle struct</description>
<operation>insert_after</operation>
<old_text>
pub struct Circle {
radius: f64,
</old_text>
<new_text>
```rust
struct Rectangle {
// existing fields ...
origin: (f64, f64),
</new_text>
}
```
```rust
struct Circle {
// existing fields ...
origin: (f64, f64),
}
```
<edit>
<path>src/shapes/rectangle.rs</path>
<operation>insert_before</operation>
<search>
width: f64,
height: f64,
</search>
<description>Add the origin field to Rectangle</description>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<description>Update the Circle's new function to take an origin parameter</description>
<operation>insert_before</operation>
<search>
radius: f64,
</search>
<description>Add the origin field to Circle</description>
</edit>
<step>
Update both shape's constructors to take an origin.
<edit>
<path>src/shapes/rectangle.rs</path>
<operation>update</operation>
<old_text>
fn new(radius: f64) -> Self {
Circle { radius }
}
</old_text>
<new_text>
fn new(origin: (f64, f64), radius: f64) -> Self {
Circle { origin, radius }
}
</new_text>
<search>
fn new(width: f64, height: f64) -> Self {
Rectangle { width, height }
}
</search>
<description>Update the Rectangle new function to take an origin</description>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<operation>update</operation>
<search>
fn new(radius: f64) -> Self {
Circle { radius }
}
</search>
<description>Update the Circle new function to take an origin</description>
</edit>
</step>
<step>
Implement Display for both shapes
<edit>
<path>src/shapes/rectangle.rs</path>
<description>Add an import for the std::fmt module</description>
<operation>insert_before</operation>
<old_text>
<search>
struct Rectangle {
</old_text>
<new_text>
use std::fmt;
</new_text>
</search>
<description>Add an import for the `std::fmt` module</description>
</edit>
<edit>
<path>src/shapes/rectangle.rs</path>
<description>Add a Display implementation for Rectangle</description>
<operation>insert_after</operation>
<old_text>
<search>
Rectangle { width, height }
}
}
</old_text>
<new_text>
impl fmt::Display for Rectangle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Rectangle")
.field("origin", &self.origin)
.field("width", &self.width)
.field("height", &self.height)
.finish()
}
}
</new_text>
</search>
<description>Add a Display implementation for Rectangle</description>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<description>Add an import for the `std::fmt` module</description>
<operation>insert_before</operation>
<old_text>
<search>
struct Circle {
</old_text>
<new_text>
use std::fmt;
</new_text>
</search>
<description>Add an import for the `std::fmt` module</description>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<description>Add a Display implementation for Circle</description>
<operation>insert_after</operation>
<old_text>
<search>
Circle { radius }
}
}
</old_text>
<new_text>
impl fmt::Display for Circle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Circle")
.field("origin", &self.origin)
.field("radius", &self.radius)
.finish()
}
}
</new_text>
</search>
<description>Add a Display implementation for Circle</description>
</edit>
</patch>
</step>
</message>
</example>
<example number="2">
<message role="user">
```rs src/user.rs
struct User {
pub name: String,
age: u32,
email: String,
}
impl User {
fn new(name: String, age: u32, email: String) -> Self {
User { name, age, email }
}
pub fn print_info(&self) {
todo!()
}
}
```
Let's print all the user information and delete the email field.
</message>
<message role="assistant">
<step>
Update the 'print_info' method to print user information
```rust
impl User {
// ... other methods ...
pub fn print_info(&self) {
println!("Name: {name}, Age: {age}", name = self.name, age = self.age);
}
}
```
<edit>
<path>src/user.rs</path>
<operation>update</operation>
<search>
pub fn print_info(&self) {
todo!()
}
</search>
<description>Print all the user information</description>
</edit>
</step>
<step>
Remove the 'email' field from the User struct
<edit>
<path>src/user.rs</path>
<operation>delete</operation>
<search>
email: String,
</search>
</edit>
<edit>
<path>src/user.rs</path>
<operation>update</operation>
<search>
fn new(name: String, age: u32, email: String) -> Self {
User { name, age, email }
}
</search>
<description>Remove email parameter from new method</description>
</edit>
</step>
</message>
</example>
You should think step by step. When possible, produce smaller, coherent logical steps as opposed to one big step that combines lots of heterogeneous edits.
</task_description>
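For illustration, a minimal sketch of how an engine might apply one search-anchored edit of the kind the prompt above describes (hypothetical helper, not Zed's implementation; the `create` operation is omitted):

```rust
// Hypothetical sketch: apply one <edit> with a unique <search> anchor.
enum Operation {
    Update,
    InsertBefore,
    InsertAfter,
    Delete,
}

fn apply_edit(source: &str, search: &str, new_text: &str, op: Operation) -> Option<String> {
    // The prompt requires the search string to be unambiguous; bail out otherwise.
    let start = source.find(search)?;
    if source[start + search.len()..].contains(search) {
        return None; // ambiguous anchor
    }
    let end = start + search.len();
    Some(match op {
        Operation::Update => format!("{}{}{}", &source[..start], new_text, &source[end..]),
        Operation::InsertBefore => format!("{}{}{}", &source[..start], new_text, &source[start..]),
        Operation::InsertAfter => format!("{}{}{}", &source[..end], new_text, &source[end..]),
        Operation::Delete => format!("{}{}", &source[..start], &source[end..]),
    })
}
```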

View File

@@ -0,0 +1,496 @@
<overview>
Your task is to map a step from a workflow to locations in source code where code needs to be changed to fulfill that step.
Given a workflow containing background context plus a series of <step> tags, you will resolve *one* of these step tags to one or more locations in the code.
With each location, you will produce a brief, one-line description of the changes to be made.
<guidelines>
- There's no need to describe *what* to do, just *where* to do it.
- Only reference locations that actually exist (unless you're creating a file).
- If creating a file, assume any subsequent updates are included at the time of creation.
- Don't create and then update a file. Always create new files in one shot.
- Prefer updating symbols lower in the syntax tree if possible.
- Never include suggestions on a parent symbol and one of its children in the same suggestions block.
- Never nest an operation with another operation or include CDATA or other content. All suggestions are leaf nodes.
- Descriptions are required for all suggestions except delete.
- When generating multiple suggestions, ensure the descriptions are specific to each individual operation.
- Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
- Don't generate multiple suggestions at the same location. Instead, combine them together in a single operation with a succinct combined description.
- To add imports respond with a suggestion where the `"symbol"` key is set to `"#imports"`
</guidelines>
</overview>
<examples>
<example>
<workflow_context>
<message role="user">
```rs src/rectangle.rs
struct Rectangle {
width: f64,
height: f64,
}
impl Rectangle {
fn new(width: f64, height: f64) -> Self {
Rectangle { width, height }
}
}
```
We need to add methods to calculate the area and perimeter of the rectangle. Can you help with that?
</message>
<message role="assistant">
Sure, I can help with that!
<step>Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct</step>
<step>Implement the 'Display' trait for the Rectangle struct</step>
</message>
</workflow_context>
<step_to_resolve>
Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct
</step_to_resolve>
<incorrect_output reason="NEVER append multiple children at the same location.">
{
"title": "Add Rectangle methods",
"suggestions": [
{
"kind": "AppendChild",
"path": "src/shapes.rs",
"symbol": "impl Rectangle",
"description": "Add calculate_area method"
},
{
"kind": "AppendChild",
"path": "src/shapes.rs",
"symbol": "impl Rectangle",
"description": "Add calculate_perimeter method"
}
]
}
</incorrect_output>
<correct_output>
{
"title": "Add Rectangle methods",
"suggestions": [
{
"kind": "AppendChild",
"path": "src/shapes.rs",
"symbol": "impl Rectangle",
"description": "Add calculate area and perimeter methods"
}
]
}
</correct_output>
<step_to_resolve>
Implement the 'Display' trait for the Rectangle struct
</step_to_resolve>
<output>
{
"title": "Implement Display for Rectangle",
"suggestions": [
{
"kind": "InsertSiblingAfter",
"path": "src/shapes.rs",
"symbol": "impl Rectangle",
"description": "Implement Display trait for Rectangle"
}
]
}
</output>
</example>
<example>
<workflow_context>
<message role="user">
```rs src/user.rs
struct User {
pub name: String,
age: u32,
email: String,
}
impl User {
fn new(name: String, age: u32, email: String) -> Self {
User { name, age, email }
}
pub fn print_info(&self) {
println!("Name: {}, Age: {}, Email: {}", self.name, self.age, self.email);
}
}
```
</message>
<message role="assistant">
Certainly!
<step>Update the 'print_info' method to use formatted output</step>
<step>Remove the 'email' field from the User struct</step>
</message>
</workflow_context>
<step_to_resolve>
Update the 'print_info' method to use formatted output
</step_to_resolve>
<output>
{
"title": "Use formatted output",
"suggestions": [
{
"kind": "Update",
"path": "src/user.rs",
"symbol": "impl User pub fn print_info",
"description": "Use formatted output"
}
]
}
</output>
<step_to_resolve>
Remove the 'email' field from the User struct
</step_to_resolve>
<output>
{
"title": "Remove email field",
"suggestions": [
{
"kind": "Delete",
"path": "src/user.rs",
"symbol": "struct User email"
}
]
}
</output>
</example>
<example>
<workflow_context>
<message role="user">
```rs src/vehicle.rs
struct Vehicle {
make: String,
model: String,
year: u32,
}
impl Vehicle {
fn new(make: String, model: String, year: u32) -> Self {
Vehicle { make, model, year }
}
fn print_year(&self) {
println!("Year: {}", self.year);
}
}
```
</message>
<message role="assistant">
<step>Add a 'use std::fmt;' statement at the beginning of the file</step>
<step>Add a new method 'start_engine' in the Vehicle impl block</step>
</message>
</workflow_context>
<step_to_resolve>
Add a 'use std::fmt;' statement at the beginning of the file
</step_to_resolve>
<output>
{
"title": "Add use std::fmt statement",
"suggestions": [
{
"kind": "PrependChild",
"path": "src/vehicle.rs",
"symbol": "#imports",
"description": "Add 'use std::fmt' statement"
}
]
}
</output>
<step_to_resolve>
Add a new method 'start_engine' in the Vehicle impl block
</step_to_resolve>
<output>
{
"title": "Add start_engine method",
"suggestions": [
{
"kind": "InsertSiblingAfter",
"path": "src/vehicle.rs",
"symbol": "impl Vehicle fn new",
"description": "Add start_engine method"
}
]
}
</output>
</example>
<example>
<workflow_context>
<message role="user">
```rs src/employee.rs
struct Employee {
name: String,
position: String,
salary: u32,
department: String,
}
impl Employee {
fn new(name: String, position: String, salary: u32, department: String) -> Self {
Employee { name, position, salary, department }
}
fn print_details(&self) {
println!("Name: {}, Position: {}, Salary: {}, Department: {}",
self.name, self.position, self.salary, self.department);
}
fn give_raise(&mut self, amount: u32) {
self.salary += amount;
}
}
```
</message>
<message role="assistant">
<step>Make salary an f32</step>
<step>Remove the 'department' field and update the 'print_details' method</step>
</message>
</workflow_context>
<step_to_resolve>
Make salary an f32
</step_to_resolve>
<incorrect_output reason="NEVER include suggestions on a parent symbol and one of its children in the same suggestions block.">
{
"title": "Change salary to f32",
"suggestions": [
{
"kind": "Update",
"path": "src/employee.rs",
"symbol": "struct Employee",
"description": "Change the type of salary to an f32"
},
{
"kind": "Update",
"path": "src/employee.rs",
"symbol": "struct Employee salary",
"description": "Change the type to an f32"
}
]
}
</incorrect_output>
<correct_output>
{
"title": "Change salary to f32",
"suggestions": [
{
"kind": "Update",
"path": "src/employee.rs",
"symbol": "struct Employee salary",
"description": "Change the type to an f32"
}
]
}
</correct_output>
<step_to_resolve>
Remove the 'department' field and update the 'print_details' method
</step_to_resolve>
<output>
{
"title": "Remove department",
"suggestions": [
{
"kind": "Delete",
"path": "src/employee.rs",
"symbol": "struct Employee department"
},
{
"kind": "Update",
"path": "src/employee.rs",
"symbol": "impl Employee fn print_details",
"description": "Don't print the 'department' field"
}
]
}
</output>
</example>
<example>
<workflow_context>
<message role="user">
```rs src/game.rs
struct Player {
name: String,
health: i32,
pub score: u32,
}
impl Player {
pub fn new(name: String) -> Self {
Player { name, health: 100, score: 0 }
}
}
struct Game {
players: Vec<Player>,
}
impl Game {
fn new() -> Self {
Game { players: Vec::new() }
}
}
```
</message>
<message role="assistant">
<step>Add a 'level' field to Player and update the 'new' method</step>
</message>
</workflow_context>
<step_to_resolve>
Add a 'level' field to Player and update the 'new' method
</step_to_resolve>
<output>
{
"title": "Add level field to Player",
"suggestions": [
{
"kind": "InsertSiblingAfter",
"path": "src/game.rs",
"symbol": "struct Player pub score",
"description": "Add level field to Player"
},
{
"kind": "Update",
"path": "src/game.rs",
"symbol": "impl Player pub fn new",
"description": "Initialize level in new method"
}
]
}
</output>
</example>
<example>
<workflow_context>
<message role="user">
```rs src/config.rs
use std::collections::HashMap;
struct Config {
settings: HashMap<String, String>,
}
impl Config {
fn new() -> Self {
Config { settings: HashMap::new() }
}
}
```
</message>
<message role="assistant">
<step>Add a 'load_from_file' method to Config and import necessary modules</step>
</message>
</workflow_context>
<step_to_resolve>
Add a 'load_from_file' method to Config and import necessary modules
</step_to_resolve>
<output>
{
"title": "Add load_from_file method",
"suggestions": [
{
"kind": "PrependChild",
"path": "src/config.rs",
"symbol": "#imports",
"description": "Import std::fs and std::io modules"
},
{
"kind": "AppendChild",
"path": "src/config.rs",
"symbol": "impl Config",
"description": "Add load_from_file method"
}
]
}
</output>
</example>
<example>
<workflow_context>
<message role="user">
```rs src/database.rs
pub(crate) struct Database {
connection: Connection,
}
impl Database {
fn new(url: &str) -> Result<Self, Error> {
let connection = Connection::connect(url)?;
Ok(Database { connection })
}
async fn query(&self, sql: &str) -> Result<Vec<Row>, Error> {
self.connection.query(sql, &[])
}
}
```
</message>
<message role="assistant">
<step>Add error handling to the 'query' method and create a custom error type</step>
</message>
</workflow_context>
<step_to_resolve>
Add error handling to the 'query' method and create a custom error type
</step_to_resolve>
<output>
{
"title": "Add error handling to query",
"suggestions": [
{
"kind": "PrependChild",
"path": "src/database.rs",
"description": "Import necessary error handling modules"
},
{
"kind": "InsertSiblingBefore",
"path": "src/database.rs",
"symbol": "pub(crate) struct Database",
"description": "Define custom DatabaseError enum"
},
{
"kind": "Update",
"path": "src/database.rs",
"symbol": "impl Database async fn query",
"description": "Implement error handling in query method"
}
]
}
</output>
</example>
</examples>
Now generate the suggestions for the following step:
<workflow_context>
{{{workflow_context}}}
</workflow_context>
<step_to_resolve>
{{{step_to_resolve}}}
</step_to_resolve>
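For illustration, the suggestion payload from the examples above modeled as Rust types (assumes the `serde` and `serde_json` crates; these are hypothetical types, not Zed's):

```rust
// Sketch: deserialize the suggestion JSON shown in the examples above.
// Hypothetical types; field names taken from the example outputs.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct StepResolution {
    title: String,
    suggestions: Vec<Suggestion>,
}

#[derive(Debug, Deserialize)]
struct Suggestion {
    kind: String, // e.g. "Update", "Delete", "AppendChild", "PrependChild", ...
    path: String,
    #[serde(default)]
    symbol: Option<String>, // "#imports" for import insertions
    #[serde(default)]
    description: Option<String>, // optional for Delete
}

fn main() -> serde_json::Result<()> {
    let raw = r#"{
        "title": "Remove email field",
        "suggestions": [
            { "kind": "Delete", "path": "src/user.rs", "symbol": "struct User email" }
        ]
    }"#;
    let parsed: StepResolution = serde_json::from_str(raw)?;
    println!("{parsed:?}");
    Ok(())
}
```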

View File

@@ -27,14 +27,14 @@
"inline_completion_provider": "copilot"
},
// The name of a font to use for rendering text in the editor
"buffer_font_family": "Zed Plex Mono",
"buffer_font_family": "IBM Plex Mono",
// Set the buffer text's font fallbacks, this will be merged with
// the platform's default fallbacks.
"buffer_font_fallbacks": null,
// The OpenType features to enable for text in the editor.
"buffer_font_features": {
// Disable ligatures:
// "calt": false
"calt": false
},
// The default font size for text in the editor
"buffer_font_size": 15,
@@ -53,7 +53,7 @@
"buffer_line_height": "comfortable",
// The name of a font to use for rendering text in the UI
// You can set this to ".SystemUIFont" to use the system font
"ui_font_family": "Zed Plex Sans",
"ui_font_family": "IBM Plex Sans",
// Set the UI's font fallbacks, this will be merged with the platform's
// default font fallbacks.
"ui_font_fallbacks": null,
@@ -494,14 +494,7 @@
// Position of the close button on the editor tabs.
"close_position": "right",
// Whether to show the file icon for a tab.
"file_icons": false,
// What to do after closing the current tab.
//
// 1. Activate the tab that was open previously (default)
// "History"
// 2. Activate the neighbour tab (prefers the right one, if present)
// "Neighbour"
"activate_on_close": "history"
"file_icons": false
},
// Settings related to preview tabs.
"preview_tabs": {
@@ -712,10 +705,10 @@
// May take 2 values:
// 1. Rely on default platform handling of option key, on macOS
// this means generating certain unicode characters
// "option_as_meta": false,
// "option_to_meta": false,
// 2. Make the option keys behave as a 'meta' key, e.g. for emacs
// "option_as_meta": true,
"option_as_meta": false,
// "option_to_meta": true,
"option_as_meta": true,
// Whether or not selecting text in the terminal will automatically
// copy to the system clipboard.
"copy_on_select": false,
@@ -760,7 +753,7 @@
// "font_size": 15,
// Set the terminal's font family. If this option is not included,
// the terminal will default to matching the buffer's font family.
// "font_family": "Zed Plex Mono",
// "font_family": "IBM Plex Mono",
// Set the terminal's font fallbacks. If this option is not included,
// the terminal will default to matching the buffer's font fallbacks.
// This will be merged with the platform's default font fallbacks
@@ -824,7 +817,6 @@
// Different settings for specific languages.
"languages": {
"Astro": {
"language_servers": ["astro-language-server", "..."],
"prettier": {
"allowed": true,
"plugins": ["prettier-plugin-astro"]
@@ -851,10 +843,6 @@
"Dart": {
"tab_size": 2
},
"Diff": {
"remove_trailing_whitespace_on_save": false,
"ensure_final_newline_on_save": false
},
"Elixir": {
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
},

View File

@@ -1,7 +0,0 @@
// Server-specific settings
//
// For a full list of overridable settings, and general information on settings,
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
{
"lsp": {}
}

View File

@@ -26,3 +26,6 @@ serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
util.workspace = true
[dev-dependencies]
tokio.workspace = true

View File

@@ -58,7 +58,7 @@ impl Assets {
pub fn load_test_fonts(&self, cx: &AppContext) {
cx.text_system()
.add_fonts(vec![self
.load("fonts/plex-mono/ZedPlexMono-Regular.ttf")
.load("fonts/ibm-plex-mono/IBMPlexMono-Regular.ttf")
.unwrap()
.unwrap()])
.unwrap()

View File

@@ -97,7 +97,6 @@ language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
languages = { workspace = true, features = ["test-support"] }
log.workspace = true
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
serde_json_lenient.workspace = true

View File

@@ -6,7 +6,6 @@ mod context;
pub mod context_store;
mod inline_assistant;
mod model_selector;
mod patch;
mod prompt_library;
mod prompts;
mod slash_command;
@@ -15,6 +14,7 @@ pub mod slash_command_settings;
mod streaming_diff;
mod terminal_inline_assistant;
mod tools;
mod workflow;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
@@ -35,13 +35,11 @@ use language_model::{
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
};
pub(crate) use model_selector::*;
pub use patch::*;
pub use prompts::PromptBuilder;
use prompts::PromptLoadingParams;
use semantic_index::{CloudEmbeddingProvider, SemanticDb};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::workflow_command::WorkflowSlashCommand;
use slash_command::{
auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
@@ -52,6 +50,7 @@ use std::path::PathBuf;
use std::sync::Arc;
pub(crate) use streaming_diff::*;
use util::ResultExt;
pub use workflow::*;
use crate::slash_command_settings::SlashCommandSettings;
@@ -394,25 +393,12 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
slash_command_registry.register_command(now_command::NowSlashCommand, false);
slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
if let Some(prompt_builder) = prompt_builder {
cx.observe_global::<SettingsStore>({
let slash_command_registry = slash_command_registry.clone();
let prompt_builder = prompt_builder.clone();
move |cx| {
if AssistantSettings::get_global(cx).are_live_diffs_enabled(cx) {
slash_command_registry.register_command(
workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
true,
);
} else {
slash_command_registry.unregister_command_by_name(WorkflowSlashCommand::NAME);
}
}
})
.detach();
slash_command_registry.register_command(
workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
true,
);
cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
let slash_command_registry = slash_command_registry.clone();
move |is_enabled, _cx| {

File diff suppressed because it is too large

View File

@@ -2,7 +2,6 @@ use std::sync::Arc;
use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use feature_flags::FeatureFlagAppExt;
use fs::Fs;
use gpui::{AppContext, Pixels};
use language_model::provider::open_ai;
@@ -62,13 +61,6 @@ pub struct AssistantSettings {
pub default_model: LanguageModelSelection,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
}
impl AssistantSettings {
pub fn are_live_diffs_enabled(&self, cx: &AppContext) -> bool {
cx.is_staff() || self.enable_experimental_live_diffs
}
}
/// Assistant panel settings
@@ -246,7 +238,6 @@ impl AssistantSettingsContent {
}
}),
inline_alternatives: None,
enable_experimental_live_diffs: None,
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
},
@@ -266,7 +257,6 @@ impl AssistantSettingsContent {
.to_string(),
}),
inline_alternatives: None,
enable_experimental_live_diffs: None,
},
}
}
@@ -383,7 +373,6 @@ impl Default for VersionedAssistantSettingsContent {
default_height: None,
default_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
})
}
}
@@ -414,10 +403,6 @@ pub struct AssistantSettingsContentV2 {
default_model: Option<LanguageModelSelection>,
/// Additional models with which to generate alternatives when performing inline assists.
inline_alternatives: Option<Vec<LanguageModelSelection>>,
/// Enable experimental live diffs in the assistant panel.
///
/// Default: false
enable_experimental_live_diffs: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -540,10 +525,7 @@ impl Settings for AssistantSettings {
);
merge(&mut settings.default_model, value.default_model);
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.enable_experimental_live_diffs,
value.enable_experimental_live_diffs,
);
// merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
}
Ok(settings)
@@ -602,7 +584,6 @@ mod tests {
dock: None,
default_width: None,
default_height: None,
enable_experimental_live_diffs: None,
}),
)
},

View File

@@ -2,8 +2,8 @@
mod context_tests;
use crate::{
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantEdit, AssistantPatch,
AssistantPatchStatus, MessageId, MessageStatus,
prompts::PromptBuilder, slash_command::SlashCommandLine, MessageId, MessageStatus,
WorkflowStep, WorkflowStepEdit, WorkflowStepResolution, WorkflowSuggestionGroup,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@@ -15,15 +15,17 @@ use clock::ReplicaId;
use collections::{HashMap, HashSet};
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use fs::{Fs, RemoveOptions};
use futures::{future::Shared, FutureExt, StreamExt};
use futures::{
future::{self, Shared},
FutureExt, StreamExt,
};
use gpui::{
AppContext, Context as _, EventEmitter, Model, ModelContext, RenderImage, SharedString,
Subscription, Task,
AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, RenderImage,
SharedString, Subscription, Task,
};
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
use language_model::{
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
@@ -35,7 +37,7 @@ use project::Project;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::{
cmp::{max, Ordering},
cmp::{self, max, Ordering},
fmt::Debug,
iter, mem,
ops::Range,
@@ -292,12 +294,10 @@ impl ContextOperation {
#[derive(Debug, Clone)]
pub enum ContextEvent {
ShowAssistError(SharedString),
ShowPaymentRequiredError,
ShowMaxMonthlySpendReachedError,
MessagesEdited,
SummaryChanged,
StreamedCompletion,
PatchesUpdated {
WorkflowStepsUpdated {
removed: Vec<Range<language::Anchor>>,
updated: Vec<Range<language::Anchor>>,
},
@@ -451,14 +451,13 @@ pub struct XmlTag {
#[derive(Copy, Clone, Debug, strum::EnumString, PartialEq, Eq, strum::AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum XmlTagKind {
Patch,
Title,
Step,
Edit,
Path,
Description,
OldText,
NewText,
Search,
Within,
Operation,
Description,
}
pub struct Context {
@@ -488,7 +487,7 @@ pub struct Context {
_subscriptions: Vec<Subscription>,
telemetry: Option<Arc<Telemetry>>,
language_registry: Arc<LanguageRegistry>,
patches: Vec<AssistantPatch>,
workflow_steps: Vec<WorkflowStep>,
xml_tags: Vec<XmlTag>,
project: Option<Model<Project>>,
prompt_builder: Arc<PromptBuilder>,
@@ -504,7 +503,7 @@ impl ContextAnnotation for PendingSlashCommand {
}
}
impl ContextAnnotation for AssistantPatch {
impl ContextAnnotation for WorkflowStep {
fn range(&self) -> &Range<language::Anchor> {
&self.range
}
@@ -589,7 +588,7 @@ impl Context {
telemetry,
project,
language_registry,
patches: Vec::new(),
workflow_steps: Vec::new(),
xml_tags: Vec::new(),
prompt_builder,
};
@@ -927,49 +926,48 @@ impl Context {
self.summary.as_ref()
}
pub(crate) fn patch_containing(
pub(crate) fn workflow_step_containing(
&self,
position: Point,
offset: usize,
cx: &AppContext,
) -> Option<&AssistantPatch> {
) -> Option<&WorkflowStep> {
let buffer = self.buffer.read(cx);
let index = self.patches.binary_search_by(|patch| {
let patch_range = patch.range.to_point(&buffer);
if position < patch_range.start {
Ordering::Greater
} else if position > patch_range.end {
Ordering::Less
} else {
Ordering::Equal
}
});
if let Ok(ix) = index {
Some(&self.patches[ix])
} else {
None
}
let index = self
.workflow_steps
.binary_search_by(|step| {
let step_range = step.range.to_offset(&buffer);
if offset < step_range.start {
Ordering::Greater
} else if offset > step_range.end {
Ordering::Less
} else {
Ordering::Equal
}
})
.ok()?;
Some(&self.workflow_steps[index])
}
pub fn patch_ranges(&self) -> impl Iterator<Item = Range<language::Anchor>> + '_ {
self.patches.iter().map(|patch| patch.range.clone())
pub fn workflow_step_ranges(&self) -> impl Iterator<Item = Range<language::Anchor>> + '_ {
self.workflow_steps.iter().map(|step| step.range.clone())
}
pub(crate) fn patch_for_range(
pub(crate) fn workflow_step_for_range(
&self,
range: &Range<language::Anchor>,
cx: &AppContext,
) -> Option<&AssistantPatch> {
) -> Option<&WorkflowStep> {
let buffer = self.buffer.read(cx);
let index = self.patch_index_for_range(range, buffer).ok()?;
Some(&self.patches[index])
let index = self.workflow_step_index_for_range(range, buffer).ok()?;
Some(&self.workflow_steps[index])
}
fn patch_index_for_range(
fn workflow_step_index_for_range(
&self,
tagged_range: &Range<text::Anchor>,
buffer: &text::BufferSnapshot,
) -> Result<usize, usize> {
self.patches
self.workflow_steps
.binary_search_by(|probe| probe.range.cmp(&tagged_range, buffer))
}
@@ -1017,6 +1015,8 @@ impl Context {
language::BufferEvent::Edited => {
self.count_remaining_tokens(cx);
self.reparse(cx);
// Use `inclusive = true` to invalidate a step when an edit occurs
// at the start/end of a parsed step.
cx.emit(ContextEvent::MessagesEdited);
}
_ => {}
@@ -1245,8 +1245,8 @@ impl Context {
let mut removed_slash_command_ranges = Vec::new();
let mut updated_slash_commands = Vec::new();
let mut removed_patches = Vec::new();
let mut updated_patches = Vec::new();
let mut removed_steps = Vec::new();
let mut updated_steps = Vec::new();
while let Some(mut row_range) = row_ranges.next() {
while let Some(next_row_range) = row_ranges.peek() {
if row_range.end >= next_row_range.start {
@@ -1270,11 +1270,11 @@ impl Context {
&mut removed_slash_command_ranges,
cx,
);
self.reparse_patches_in_range(
self.reparse_workflow_steps_in_range(
start..end,
&buffer,
&mut updated_patches,
&mut removed_patches,
&mut updated_steps,
&mut removed_steps,
cx,
);
}
@@ -1286,10 +1286,10 @@ impl Context {
});
}
if !updated_patches.is_empty() || !removed_patches.is_empty() {
cx.emit(ContextEvent::PatchesUpdated {
removed: removed_patches,
updated: updated_patches,
if !updated_steps.is_empty() || !removed_steps.is_empty() {
cx.emit(ContextEvent::WorkflowStepsUpdated {
removed: removed_steps,
updated: updated_steps,
});
}
}
@@ -1351,7 +1351,7 @@ impl Context {
removed.extend(removed_commands.map(|command| command.source_range));
}
fn reparse_patches_in_range(
fn reparse_workflow_steps_in_range(
&mut self,
range: Range<text::Anchor>,
buffer: &BufferSnapshot,
@@ -1366,32 +1366,41 @@ impl Context {
self.xml_tags
.splice(intersecting_tags_range.clone(), new_tags);
// Find which patches intersect the changed range.
let intersecting_patches_range =
self.indices_intersecting_buffer_range(&self.patches, range.clone(), cx);
// Find which steps intersect the changed range.
let intersecting_steps_range =
self.indices_intersecting_buffer_range(&self.workflow_steps, range.clone(), cx);
// Reparse all tags after the last unchanged patch before the change.
// Reparse all tags after the last unchanged step before the change.
let mut tags_start_ix = 0;
if let Some(preceding_unchanged_patch) =
self.patches[..intersecting_patches_range.start].last()
if let Some(preceding_unchanged_step) =
self.workflow_steps[..intersecting_steps_range.start].last()
{
tags_start_ix = match self.xml_tags.binary_search_by(|tag| {
tag.range
.start
.cmp(&preceding_unchanged_patch.range.end, buffer)
.cmp(&preceding_unchanged_step.range.end, buffer)
.then(Ordering::Less)
}) {
Ok(ix) | Err(ix) => ix,
};
}
// Rebuild the patches in the range.
let new_patches = self.parse_patches(tags_start_ix, range.end, buffer, cx);
updated.extend(new_patches.iter().map(|patch| patch.range.clone()));
let removed_patches = self.patches.splice(intersecting_patches_range, new_patches);
// Rebuild the edit suggestions in the range.
let mut new_steps = self.parse_steps(tags_start_ix, range.end, buffer);
if let Some(project) = self.project() {
for step in &mut new_steps {
Self::resolve_workflow_step_internal(step, &project, cx);
}
}
updated.extend(new_steps.iter().map(|step| step.range.clone()));
let removed_steps = self
.workflow_steps
.splice(intersecting_steps_range, new_steps);
removed.extend(
removed_patches
.map(|patch| patch.range)
removed_steps
.map(|step| step.range)
.filter(|range| !updated.contains(&range)),
);
}
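
The incremental reparse above leans on `Vec::splice`: find the index range of steps that intersect the edited region, rebuild only those, and splice the replacements in while collecting whatever fell out as removed. A minimal sketch of that pattern, with plain integers standing in for anchored steps (the values are hypothetical):

```rust
fn main() {
    let mut steps = vec![10, 20, 30, 40, 50];

    // Suppose an edit invalidates the steps at indices 1..3 (20 and 30),
    // and reparsing that region of the buffer yields two replacements.
    let reparsed = vec![21, 29];

    // `splice` swaps the slice in place and hands back the removed items.
    let removed: Vec<_> = steps.splice(1..3, reparsed).collect();

    assert_eq!(removed, vec![20, 30]);
    assert_eq!(steps, vec![10, 21, 29, 40, 50]);
}
```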
@@ -1452,95 +1461,60 @@ impl Context {
tags
}
fn parse_patches(
fn parse_steps(
&mut self,
tags_start_ix: usize,
buffer_end: text::Anchor,
buffer: &BufferSnapshot,
cx: &AppContext,
) -> Vec<AssistantPatch> {
let mut new_patches = Vec::new();
let mut pending_patch = None;
let mut patch_tag_depth = 0;
) -> Vec<WorkflowStep> {
let mut new_steps = Vec::new();
let mut pending_step = None;
let mut edit_step_depth = 0;
let mut tags = self.xml_tags[tags_start_ix..].iter().peekable();
'tags: while let Some(tag) = tags.next() {
if tag.range.start.cmp(&buffer_end, buffer).is_gt() && patch_tag_depth == 0 {
if tag.range.start.cmp(&buffer_end, buffer).is_gt() && edit_step_depth == 0 {
break;
}
if tag.kind == XmlTagKind::Patch && tag.is_open_tag {
patch_tag_depth += 1;
let patch_start = tag.range.start;
let mut edits = Vec::<Result<AssistantEdit>>::new();
let mut patch = AssistantPatch {
range: patch_start..patch_start,
title: String::new().into(),
if tag.kind == XmlTagKind::Step && tag.is_open_tag {
edit_step_depth += 1;
let edit_start = tag.range.start;
let mut edits = Vec::new();
let mut step = WorkflowStep {
range: edit_start..edit_start,
leading_tags_end: tag.range.end,
trailing_tag_start: None,
edits: Default::default(),
status: crate::AssistantPatchStatus::Pending,
resolution: None,
resolution_task: None,
};
while let Some(tag) = tags.next() {
if tag.kind == XmlTagKind::Patch && !tag.is_open_tag {
patch_tag_depth -= 1;
if patch_tag_depth == 0 {
patch.range.end = tag.range.end;
step.trailing_tag_start.get_or_insert(tag.range.start);
// Include the line immediately after this closing </patch> tag if it's empty.
let patch_end_offset = patch.range.end.to_offset(buffer);
let mut patch_end_chars = buffer.chars_at(patch_end_offset);
if patch_end_chars.next() == Some('\n')
&& patch_end_chars.next().map_or(true, |ch| ch == '\n')
{
let messages = self.messages_for_offsets(
[patch_end_offset, patch_end_offset + 1],
cx,
);
if messages.len() == 1 {
patch.range.end = buffer.anchor_before(patch_end_offset + 1);
}
}
edits.sort_unstable_by(|a, b| {
if let (Ok(a), Ok(b)) = (a, b) {
a.path.cmp(&b.path)
} else {
Ordering::Equal
}
});
patch.edits = edits.into();
patch.status = AssistantPatchStatus::Ready;
new_patches.push(patch);
if tag.kind == XmlTagKind::Step && !tag.is_open_tag {
// step.trailing_tag_start = Some(tag.range.start);
edit_step_depth -= 1;
if edit_step_depth == 0 {
step.range.end = tag.range.end;
step.edits = edits.into();
new_steps.push(step);
continue 'tags;
}
}
if tag.kind == XmlTagKind::Title && tag.is_open_tag {
let content_start = tag.range.end;
while let Some(tag) = tags.next() {
if tag.kind == XmlTagKind::Title && !tag.is_open_tag {
let content_end = tag.range.start;
patch.title =
trimmed_text_in_range(buffer, content_start..content_end)
.into();
break;
}
}
}
if tag.kind == XmlTagKind::Edit && tag.is_open_tag {
let mut path = None;
let mut old_text = None;
let mut new_text = None;
let mut search = None;
let mut operation = None;
let mut description = None;
while let Some(tag) = tags.next() {
if tag.kind == XmlTagKind::Edit && !tag.is_open_tag {
edits.push(AssistantEdit::new(
edits.push(WorkflowStepEdit::new(
path,
operation,
old_text,
new_text,
search,
description,
));
break;
@@ -1549,8 +1523,7 @@ impl Context {
if tag.is_open_tag
&& [
XmlTagKind::Path,
XmlTagKind::OldText,
XmlTagKind::NewText,
XmlTagKind::Search,
XmlTagKind::Operation,
XmlTagKind::Description,
]
@@ -1562,18 +1535,15 @@ impl Context {
if tag.kind == kind && !tag.is_open_tag {
let tag = tags.next().unwrap();
let content_end = tag.range.start;
let content = trimmed_text_in_range(
buffer,
content_start..content_end,
);
let mut content = buffer
.text_for_range(content_start..content_end)
.collect::<String>();
content.truncate(content.trim_end().len());
match kind {
XmlTagKind::Path => path = Some(content),
XmlTagKind::Operation => operation = Some(content),
XmlTagKind::OldText => {
old_text = Some(content).filter(|s| !s.is_empty())
}
XmlTagKind::NewText => {
new_text = Some(content).filter(|s| !s.is_empty())
XmlTagKind::Search => {
search = Some(content).filter(|s| !s.is_empty())
}
XmlTagKind::Description => {
description =
@@ -1588,28 +1558,162 @@ impl Context {
}
}
patch.edits = edits.into();
pending_patch = Some(patch);
pending_step = Some(step);
}
}
if let Some(mut pending_patch) = pending_patch {
let patch_start = pending_patch.range.start.to_offset(buffer);
if let Some(message) = self.message_for_offset(patch_start, cx) {
if message.anchor_range.end == text::Anchor::MAX {
pending_patch.range.end = text::Anchor::MAX;
if let Some(mut pending_step) = pending_step {
pending_step.range.end = text::Anchor::MAX;
new_steps.push(pending_step);
}
new_steps
}
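
For orientation, a sketch of the markup `parse_steps` walks, with tag names taken from `XmlTagKind` above and content adapted from the tests later in this diff; the first line after the opening `<step>` tag is what the resolver later uses as the step title:

```rust
// Illustrative only: the shape of a step as the assistant emits it.
const EXAMPLE_STEP: &str = r#"
<step>
Add a second function

<edit>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<search>fn one</search>
<description>add a `two` function</description>
</edit>
</step>
"#;
```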
pub fn resolve_workflow_step(
&mut self,
tagged_range: Range<text::Anchor>,
cx: &mut ModelContext<Self>,
) -> Option<()> {
let index = self
.workflow_step_index_for_range(&tagged_range, self.buffer.read(cx))
.ok()?;
let step = &mut self.workflow_steps[index];
let project = self.project.as_ref()?;
step.resolution.take();
Self::resolve_workflow_step_internal(step, project, cx);
None
}
fn resolve_workflow_step_internal(
step: &mut WorkflowStep,
project: &Model<Project>,
cx: &mut ModelContext<'_, Context>,
) {
step.resolution_task = Some(cx.spawn({
let range = step.range.clone();
let edits = step.edits.clone();
let project = project.clone();
|this, mut cx| async move {
let suggestion_groups =
Self::compute_step_resolution(project, edits, &mut cx).await;
this.update(&mut cx, |this, cx| {
let buffer = this.buffer.read(cx).text_snapshot();
let ix = this.workflow_step_index_for_range(&range, &buffer).ok();
if let Some(ix) = ix {
let step = &mut this.workflow_steps[ix];
let resolution = suggestion_groups.map(|suggestion_groups| {
let mut title = String::new();
for mut chunk in buffer.text_for_range(
step.leading_tags_end
..step.trailing_tag_start.unwrap_or(step.range.end),
) {
if title.is_empty() {
chunk = chunk.trim_start();
}
if let Some((prefix, _)) = chunk.split_once('\n') {
title.push_str(prefix);
break;
} else {
title.push_str(chunk);
}
}
WorkflowStepResolution {
title,
suggestion_groups,
}
});
step.resolution = Some(Arc::new(resolution));
cx.emit(ContextEvent::WorkflowStepsUpdated {
removed: vec![],
updated: vec![range],
})
}
})
.ok();
}
}));
}
async fn compute_step_resolution(
project: Model<Project>,
edits: Arc<[Result<WorkflowStepEdit>]>,
cx: &mut AsyncAppContext,
) -> Result<HashMap<Model<Buffer>, Vec<WorkflowSuggestionGroup>>> {
let mut suggestion_tasks = Vec::new();
for edit in edits.iter() {
let edit = edit.as_ref().map_err(|e| anyhow!("{e}"))?;
suggestion_tasks.push(edit.resolve(project.clone(), cx.clone()));
}
// Expand the context ranges of each suggestion and group suggestions with overlapping context ranges.
let suggestions = future::try_join_all(suggestion_tasks).await?;
let mut suggestions_by_buffer = HashMap::default();
for (buffer, suggestion) in suggestions {
suggestions_by_buffer
.entry(buffer)
.or_insert_with(Vec::new)
.push(suggestion);
}
let mut suggestion_groups_by_buffer = HashMap::default();
for (buffer, mut suggestions) in suggestions_by_buffer {
let mut suggestion_groups = Vec::<WorkflowSuggestionGroup>::new();
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
// Sort suggestions by their range so that earlier, larger ranges come first
suggestions.sort_by(|a, b| a.range().cmp(&b.range(), &snapshot));
// Merge overlapping suggestions
suggestions.dedup_by(|a, b| b.try_merge(a, &snapshot));
// Create context ranges for each suggestion
for suggestion in suggestions {
let context_range = {
let suggestion_point_range = suggestion.range().to_point(&snapshot);
let start_row = suggestion_point_range.start.row.saturating_sub(5);
let end_row =
cmp::min(suggestion_point_range.end.row + 5, snapshot.max_point().row);
let start = snapshot.anchor_before(Point::new(start_row, 0));
let end =
snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
start..end
};
if let Some(last_group) = suggestion_groups.last_mut() {
if last_group
.context_range
.end
.cmp(&context_range.start, &snapshot)
.is_ge()
{
// Merge with the previous group if context ranges overlap
last_group.context_range.end = context_range.end;
last_group.suggestions.push(suggestion);
} else {
// Create a new group
suggestion_groups.push(WorkflowSuggestionGroup {
context_range,
suggestions: vec![suggestion],
});
}
} else {
let message_end = buffer.anchor_after(message.offset_range.end - 1);
pending_patch.range.end = message_end;
// Create the first group
suggestion_groups.push(WorkflowSuggestionGroup {
context_range,
suggestions: vec![suggestion],
});
}
} else {
pending_patch.range.end = text::Anchor::MAX;
}
new_patches.push(pending_patch);
suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
}
new_patches
Ok(suggestion_groups_by_buffer)
}
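
The grouping pass in `compute_step_resolution` is easiest to see on plain row numbers: widen each suggestion's range by five rows of context, then merge neighbors whose widened ranges touch. A standalone sketch of that idea, assuming sorted `(start_row, end_row)` inputs in place of buffer anchors:

```rust
/// Merge sorted row ranges whose five-row context windows overlap,
/// mirroring how suggestions are folded into `WorkflowSuggestionGroup`s.
fn group_with_context(ranges: &[(u32, u32)], max_row: u32) -> Vec<(u32, u32)> {
    let mut groups: Vec<(u32, u32)> = Vec::new();
    for &(start, end) in ranges {
        // Five rows of context on either side, clamped to the buffer.
        let ctx = (start.saturating_sub(5), (end + 5).min(max_row));
        match groups.last_mut() {
            // The previous group's context reaches this one: extend it.
            Some(last) if last.1 >= ctx.0 => last.1 = last.1.max(ctx.1),
            // Otherwise this suggestion starts a new group.
            _ => groups.push(ctx),
        }
    }
    groups
}

fn main() {
    // Two nearby suggestions collapse into one group; a distant one stays separate.
    let groups = group_with_context(&[(10, 12), (15, 16), (40, 41)], 100);
    assert_eq!(groups, vec![(5, 21), (35, 46)]);
}
```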
pub fn pending_command_for_position(
@@ -2008,36 +2112,25 @@ impl Context {
let result = stream_completion.await;
this.update(&mut cx, |this, cx| {
let error_message = if let Some(error) = result.as_ref().err() {
if error.is::<PaymentRequiredError>() {
cx.emit(ContextEvent::ShowPaymentRequiredError);
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Canceled;
});
Some(error.to_string())
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Canceled;
});
Some(error.to_string())
let error_message = result
.as_ref()
.err()
.map(|error| error.to_string().trim().to_string());
if let Some(error_message) = error_message.as_ref() {
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
error_message.clone(),
)));
}
this.update_metadata(assistant_message_id, cx, |metadata| {
if let Some(error_message) = error_message.as_ref() {
metadata.status =
MessageStatus::Error(SharedString::from(error_message.clone()));
} else {
let error_message = error.to_string().trim().to_string();
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
error_message.clone(),
)));
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status =
MessageStatus::Error(SharedString::from(error_message.clone()));
});
Some(error_message)
}
} else {
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Done;
});
None
};
}
});
if let Some(telemetry) = this.telemetry.as_ref() {
let language_name = this
@@ -2053,7 +2146,7 @@ impl Context {
model_provider: model.provider_id().to_string(),
response_latency,
error_message,
language_name: language_name.map(|name| name.to_proto()),
language_name,
});
}
@@ -2208,11 +2301,11 @@ impl Context {
let mut updated = Vec::new();
let mut removed = Vec::new();
for range in ranges {
self.reparse_patches_in_range(range, &buffer, &mut updated, &mut removed, cx);
self.reparse_workflow_steps_in_range(range, &buffer, &mut updated, &mut removed, cx);
}
if !updated.is_empty() || !removed.is_empty() {
cx.emit(ContextEvent::PatchesUpdated { removed, updated })
cx.emit(ContextEvent::WorkflowStepsUpdated { removed, updated })
}
}
@@ -2718,24 +2811,6 @@ impl Context {
}
}
fn trimmed_text_in_range(buffer: &BufferSnapshot, range: Range<text::Anchor>) -> String {
let mut is_start = true;
let mut content = buffer
.text_for_range(range)
.map(|mut chunk| {
if is_start {
chunk = chunk.trim_start_matches('\n');
if !chunk.is_empty() {
is_start = false;
}
}
chunk
})
.collect::<String>();
content.truncate(content.trim_end().len());
content
}
#[derive(Debug, Default)]
pub struct ContextVersion {
context: clock::Global,

View File

@@ -1,7 +1,8 @@
use super::{AssistantEdit, MessageCacheMetadata};
use super::{MessageCacheMetadata, WorkflowStepEdit};
use crate::{
assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
assistant_panel, prompt_library, slash_command::file_command, CacheStatus, Context,
ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
WorkflowStepEditKind,
};
use anyhow::Result;
use assistant_slash_command::{
@@ -14,7 +15,6 @@ use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::Project;
use rand::prelude::*;
use serde_json::json;
@@ -478,15 +478,7 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx.update(prompt_library::init);
let mut settings_store = cx.update(SettingsStore::test);
cx.update(|cx| {
settings_store
.set_user_settings(
r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
cx,
)
.unwrap()
});
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
cx.update(language::init);
cx.update(Project::init_settings);
@@ -528,7 +520,7 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
»",
cx,
);
expect_patches(
expect_steps(
&context,
"
@@ -547,17 +539,17 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
one
two
«
<patch»",
<step»",
cx,
);
expect_patches(
expect_steps(
&context,
"
one
two
<patch",
<step",
&[],
cx,
);
@@ -571,24 +563,36 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
one
two
<patch«>
<step«>
Add a second function
```rust
fn two() {}
```
<edit>»",
cx,
);
expect_patches(
expect_steps(
&context,
"
one
two
«<patch>
«<step>
Add a second function
```rust
fn two() {}
```
<edit>»",
&[&[]],
cx,
);
// The full patch is added
// The full suggestion is added
edit(
&context,
"
@@ -596,46 +600,51 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
one
two
<patch>
<step>
Add a second function
```rust
fn two() {}
```
<edit>«
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn one</old_text>
<new_text>
fn two() {}
</new_text>
<search>fn one</search>
<description>add a `two` function</description>
</edit>
</patch>
</step>
also,»",
cx,
);
expect_patches(
expect_steps(
&context,
"
one
two
«<patch>
«<step>
Add a second function
```rust
fn two() {}
```
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn one</old_text>
<new_text>
fn two() {}
</new_text>
<search>fn one</search>
<description>add a `two` function</description>
</edit>
</patch>
»
</step>»
also,",
&[&[AssistantEdit {
&[&[WorkflowStepEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn one".into(),
new_text: "fn two() {}".into(),
kind: WorkflowStepEditKind::InsertAfter {
search: "fn one".into(),
description: "add a `two` function".into(),
},
}]],
@@ -650,46 +659,51 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
one
two
<patch>
<step>
Add a second function
```rust
fn two() {}
```
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>«fn zero»</old_text>
<new_text>
fn two() {}
</new_text>
<search>«fn zero»</search>
<description>add a `two` function</description>
</edit>
</patch>
</step>
also,",
cx,
);
expect_patches(
expect_steps(
&context,
"
one
two
«<patch>
«<step>
Add a second function
```rust
fn two() {}
```
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
<search>fn zero</search>
<description>add a `two` function</description>
</edit>
</patch>
»
</step>»
also,",
&[&[AssistantEdit {
&[&[WorkflowStepEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
kind: WorkflowStepEditKind::InsertAfter {
search: "fn zero".into(),
description: "add a `two` function".into(),
},
}]],
@@ -701,24 +715,27 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
});
expect_patches(
expect_steps(
&context,
"
one
two
<patch>
<step>
Add a second function
```rust
fn two() {}
```
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
<search>fn zero</search>
<description>add a `two` function</description>
</edit>
</patch>
</step>
also,",
&[],
@@ -729,31 +746,33 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
context.update(cx, |context, cx| {
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
});
expect_patches(
expect_steps(
&context,
"
one
two
«<patch>
«<step>
Add a second function
```rust
fn two() {}
```
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
<search>fn zero</search>
<description>add a `two` function</description>
</edit>
</patch>
»
</step>»
also,",
&[&[AssistantEdit {
&[&[WorkflowStepEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
kind: WorkflowStepEditKind::InsertAfter {
search: "fn zero".into(),
description: "add a `two` function".into(),
},
}]],
@@ -773,31 +792,33 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx,
)
});
expect_patches(
expect_steps(
&deserialized_context,
"
one
two
«<patch>
«<step>
Add a second function
```rust
fn two() {}
```
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
<search>fn zero</search>
<description>add a `two` function</description>
</edit>
</patch>
»
</step>»
also,",
&[&[AssistantEdit {
&[&[WorkflowStepEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
kind: WorkflowStepEditKind::InsertAfter {
search: "fn zero".into(),
description: "add a `two` function".into(),
},
}]],
@@ -813,58 +834,48 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx.executor().run_until_parked();
}
#[track_caller]
fn expect_patches(
fn expect_steps(
context: &Model<Context>,
expected_marked_text: &str,
expected_suggestions: &[&[AssistantEdit]],
expected_suggestions: &[&[WorkflowStepEdit]],
cx: &mut TestAppContext,
) {
let expected_marked_text = expected_marked_text.unindent();
let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
context.update(cx, |context, cx| {
let expected_marked_text = expected_marked_text.unindent();
let (expected_text, expected_ranges) = marked_text_ranges(&expected_marked_text, false);
context.buffer.read_with(cx, |buffer, _| {
assert_eq!(buffer.text(), expected_text);
let ranges = context
.patches
.workflow_steps
.iter()
.map(|entry| entry.range.to_offset(buffer))
.collect::<Vec<_>>();
(
buffer.text(),
ranges,
context
.patches
.iter()
.map(|step| step.edits.clone())
.collect::<Vec<_>>(),
)
})
let marked = generate_marked_text(&expected_text, &ranges, false);
assert_eq!(
marked,
expected_marked_text,
"unexpected suggestion ranges. actual: {ranges:?}, expected: {expected_ranges:?}"
);
let suggestions = context
.workflow_steps
.iter()
.map(|step| {
step.edits
.iter()
.map(|edit| {
let edit = edit.as_ref().unwrap();
WorkflowStepEdit {
path: edit.path.clone(),
kind: edit.kind.clone(),
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
assert_eq!(suggestions, expected_suggestions);
});
});
assert_eq!(buffer_text, expected_text);
let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
assert_eq!(actual_marked_text, expected_marked_text);
assert_eq!(
patches
.iter()
.map(|patch| {
patch
.iter()
.map(|edit| {
let edit = edit.as_ref().unwrap();
AssistantEdit {
path: edit.path.clone(),
kind: edit.kind.clone(),
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>(),
expected_suggestions
);
}
}

View File

@@ -82,6 +82,13 @@ pub struct InlineAssistant {
assists: HashMap<InlineAssistId, InlineAssist>,
assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
assist_observations: HashMap<
InlineAssistId,
(
async_watch::Sender<AssistStatus>,
async_watch::Receiver<AssistStatus>,
),
>,
confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
prompt_history: VecDeque<String>,
prompt_builder: Arc<PromptBuilder>,
@@ -89,6 +96,19 @@ pub struct InlineAssistant {
fs: Arc<dyn Fs>,
}
pub enum AssistStatus {
Idle,
Started,
Stopped,
Finished,
}
impl AssistStatus {
pub fn is_done(&self) -> bool {
matches!(self, Self::Stopped | Self::Finished)
}
}
impl Global for InlineAssistant {}
impl InlineAssistant {
@@ -103,6 +123,7 @@ impl InlineAssistant {
assists: HashMap::default(),
assists_by_editor: HashMap::default(),
assist_groups: HashMap::default(),
assist_observations: HashMap::default(),
confirmed_assists: HashMap::default(),
prompt_history: VecDeque::default(),
prompt_builder,
@@ -246,7 +267,7 @@ impl InlineAssistant {
model_provider: model.provider_id().to_string(),
response_latency: None,
error_message: None,
language_name: buffer.language().map(|language| language.name().to_proto()),
language_name: buffer.language().map(|language| language.name()),
});
}
}
@@ -767,7 +788,7 @@ impl InlineAssistant {
model_provider: model.provider_id().to_string(),
response_latency: None,
error_message: None,
language_name: language_name.map(|name| name.to_proto()),
language_name,
});
}
}
@@ -814,6 +835,17 @@ impl InlineAssistant {
.insert(assist_id, confirmed_alternative);
}
}
// Remove the assist from the status updates map
self.assist_observations.remove(&assist_id);
}
pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
return false;
};
codegen.update(cx, |this, cx| this.undo(cx));
true
}
fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
@@ -1007,6 +1039,10 @@ impl InlineAssistant {
codegen.start(user_prompt, assistant_panel_context, cx)
})
.log_err();
if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
tx.send(AssistStatus::Started).ok();
}
}
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
@@ -1017,6 +1053,25 @@ impl InlineAssistant {
};
assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
tx.send(AssistStatus::Stopped).ok();
}
}
pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
if let Some(assist) = self.assists.get(&assist_id) {
match assist.codegen.read(cx).status(cx) {
CodegenStatus::Idle => InlineAssistStatus::Idle,
CodegenStatus::Pending => InlineAssistStatus::Pending,
CodegenStatus::Done => InlineAssistStatus::Done,
CodegenStatus::Error(_) => InlineAssistStatus::Error,
}
} else if self.confirmed_assists.contains_key(&assist_id) {
InlineAssistStatus::Confirmed
} else {
InlineAssistStatus::Canceled
}
}
fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
@@ -1202,6 +1257,42 @@ impl InlineAssistant {
.collect();
})
}
pub fn observe_assist(
&mut self,
assist_id: InlineAssistId,
) -> async_watch::Receiver<AssistStatus> {
if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
rx.clone()
} else {
let (tx, rx) = async_watch::channel(AssistStatus::Idle);
self.assist_observations.insert(assist_id, (tx, rx.clone()));
rx
}
}
}
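
The `assist_observations` map is a small per-assist broadcast: the sender half stays in the map so the channel outlives individual observers, and `observe_assist` hands out clones of the receiver. A minimal sketch of the same shape using `tokio::sync::watch` in place of the `async_watch` channel used above (the `Observer` type and `main` harness here are hypothetical):

```rust
use std::collections::HashMap;
use tokio::sync::watch;

#[derive(Clone, Copy, Debug, PartialEq)]
enum Status {
    Idle,
    Started,
    Stopped,
    Finished,
}

#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct AssistId(usize);

#[derive(Default)]
struct Observer {
    // Keep both halves so the channel stays alive even with no subscribers,
    // matching how `assist_observations` stores the (Sender, Receiver) pair.
    channels: HashMap<AssistId, (watch::Sender<Status>, watch::Receiver<Status>)>,
}

impl Observer {
    /// Return a receiver for this assist, creating the channel lazily.
    fn observe(&mut self, id: AssistId) -> watch::Receiver<Status> {
        self.channels
            .entry(id)
            .or_insert_with(|| watch::channel(Status::Idle))
            .1
            .clone()
    }

    /// Notify observers; sending when nobody is listening is harmless.
    fn notify(&self, id: AssistId, status: Status) {
        if let Some((tx, _)) = self.channels.get(&id) {
            tx.send(status).ok();
        }
    }
}

#[tokio::main]
async fn main() {
    let mut observer = Observer::default();
    let mut rx = observer.observe(AssistId(1));
    observer.notify(AssistId(1), Status::Finished);
    rx.changed().await.unwrap();
    assert_eq!(*rx.borrow(), Status::Finished);
}
```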
pub enum InlineAssistStatus {
Idle,
Pending,
Done,
Error,
Confirmed,
Canceled,
}
impl InlineAssistStatus {
pub(crate) fn is_pending(&self) -> bool {
matches!(self, Self::Pending)
}
pub(crate) fn is_confirmed(&self) -> bool {
matches!(self, Self::Confirmed)
}
pub(crate) fn is_done(&self) -> bool {
matches!(self, Self::Done)
}
}
struct EditorInlineAssists {
@@ -2187,7 +2278,7 @@ impl InlineAssist {
struct InlineAssistantError;
let id =
NotificationId::composite::<InlineAssistantError>(
NotificationId::identified::<InlineAssistantError>(
assist_id.0,
);
@@ -2199,6 +2290,8 @@ impl InlineAssist {
if assist.decorations.is_none() {
this.finish_assist(assist_id, false, cx);
} else if let Some(tx) = this.assist_observations.get(&assist_id) {
tx.0.send(AssistStatus::Finished).ok();
}
}
})
@@ -2861,7 +2954,7 @@ impl CodegenAlternative {
model_provider: model_provider_id.to_string(),
response_latency,
error_message,
language_name: language_name.map(|name| name.to_proto()),
language_name,
});
}

View File

@@ -1,829 +0,0 @@
use anyhow::{anyhow, Context as _, Result};
use collections::HashMap;
use editor::ProposedChangesEditor;
use futures::{future, TryFutureExt as _};
use gpui::{AppContext, AsyncAppContext, Model, SharedString};
use language::{AutoindentMode, Buffer, BufferSnapshot};
use project::{Project, ProjectPath};
use std::{cmp, ops::Range, path::Path, sync::Arc};
use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point};
#[derive(Clone, Debug)]
pub(crate) struct AssistantPatch {
pub range: Range<language::Anchor>,
pub title: SharedString,
pub edits: Arc<[Result<AssistantEdit>]>,
pub status: AssistantPatchStatus,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum AssistantPatchStatus {
Pending,
Ready,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct AssistantEdit {
pub path: String,
pub kind: AssistantEditKind,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AssistantEditKind {
Update {
old_text: String,
new_text: String,
description: String,
},
Create {
new_text: String,
description: String,
},
InsertBefore {
old_text: String,
new_text: String,
description: String,
},
InsertAfter {
old_text: String,
new_text: String,
description: String,
},
Delete {
old_text: String,
},
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ResolvedPatch {
pub edit_groups: HashMap<Model<Buffer>, Vec<ResolvedEditGroup>>,
pub errors: Vec<AssistantPatchResolutionError>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedEditGroup {
pub context_range: Range<language::Anchor>,
pub edits: Vec<ResolvedEdit>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedEdit {
range: Range<language::Anchor>,
new_text: String,
description: Option<String>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct AssistantPatchResolutionError {
pub edit_ix: usize,
pub message: String,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
// A measure of the current quality of an in-progress fuzzy search.
//
// Stores the score accumulated so far, along with the direction of the
// preceding operation in the search.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
score: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(score: u32, direction: SearchDirection) -> Self {
Self { score, direction }
}
}
impl ResolvedPatch {
pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut AppContext) {
for (buffer, groups) in &self.edit_groups {
let branch = editor.branch_buffer_for_base(buffer).unwrap();
Self::apply_edit_groups(groups, &branch, cx);
}
editor.recalculate_all_buffer_diffs();
}
fn apply_edit_groups(
groups: &Vec<ResolvedEditGroup>,
buffer: &Model<Buffer>,
cx: &mut AppContext,
) {
let mut edits = Vec::new();
for group in groups {
for suggestion in &group.edits {
edits.push((suggestion.range.clone(), suggestion.new_text.clone()));
}
}
buffer.update(cx, |buffer, cx| {
buffer.edit(
edits,
Some(AutoindentMode::Block {
original_indent_columns: Vec::new(),
}),
cx,
);
});
}
}
impl ResolvedEdit {
pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool {
let range = &self.range;
let other_range = &other.range;
// Don't merge if we don't contain the other suggestion.
if range.start.cmp(&other_range.start, buffer).is_gt()
|| range.end.cmp(&other_range.end, buffer).is_lt()
{
return false;
}
let other_offset_range = other_range.to_offset(buffer);
let offset_range = range.to_offset(buffer);
// If the other range is empty at the start of this edit's range, combine the new text
if other_offset_range.is_empty() && other_offset_range.start == offset_range.start {
self.new_text = format!("{}\n{}", other.new_text, self.new_text);
self.range.start = other_range.start;
if let Some((description, other_description)) =
self.description.as_mut().zip(other.description.as_ref())
{
*description = format!("{}\n{}", other_description, description)
}
} else {
if let Some((description, other_description)) =
self.description.as_mut().zip(other.description.as_ref())
{
description.push('\n');
description.push_str(other_description);
}
}
true
}
}
impl AssistantEdit {
pub fn new(
path: Option<String>,
operation: Option<String>,
old_text: Option<String>,
new_text: Option<String>,
description: Option<String>,
) -> Result<Self> {
let path = path.ok_or_else(|| anyhow!("missing path"))?;
let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
let kind = match operation.as_str() {
"update" => AssistantEditKind::Update {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
"insert_before" => AssistantEditKind::InsertBefore {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
"insert_after" => AssistantEditKind::InsertAfter {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
"delete" => AssistantEditKind::Delete {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
},
"create" => AssistantEditKind::Create {
description: description.ok_or_else(|| anyhow!("missing description"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
},
_ => Err(anyhow!("unknown operation {operation:?}"))?,
};
Ok(Self { path, kind })
}
pub async fn resolve(
&self,
project: Model<Project>,
mut cx: AsyncAppContext,
) -> Result<(Model<Buffer>, ResolvedEdit)> {
let path = self.path.clone();
let kind = self.kind.clone();
let buffer = project
.update(&mut cx, |project, cx| {
let project_path = project
.find_project_path(Path::new(&path), cx)
.or_else(|| {
// If we couldn't find a project path for it, put it in the active worktree
// so that when we create the buffer, it can be saved.
let worktree = project
.active_entry()
.and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
.or_else(|| project.worktrees(cx).next())?;
let worktree = worktree.read(cx);
Some(ProjectPath {
worktree_id: worktree.id(),
path: Arc::from(Path::new(&path)),
})
})
.with_context(|| format!("worktree not found for {:?}", path))?;
anyhow::Ok(project.open_buffer(project_path, cx))
})??
.await?;
let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
let suggestion = cx
.background_executor()
.spawn(async move { kind.resolve(&snapshot) })
.await;
Ok((buffer, suggestion))
}
}
impl AssistantEditKind {
fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit {
match self {
Self::Update {
old_text,
new_text,
description,
} => {
let range = Self::resolve_location(&snapshot, &old_text);
ResolvedEdit {
range,
new_text,
description: Some(description),
}
}
Self::Create {
new_text,
description,
} => ResolvedEdit {
range: text::Anchor::MIN..text::Anchor::MAX,
description: Some(description),
new_text,
},
Self::InsertBefore {
old_text,
mut new_text,
description,
} => {
let range = Self::resolve_location(&snapshot, &old_text);
new_text.push('\n');
ResolvedEdit {
range: range.start..range.start,
new_text,
description: Some(description),
}
}
Self::InsertAfter {
old_text,
mut new_text,
description,
} => {
let range = Self::resolve_location(&snapshot, &old_text);
new_text.insert(0, '\n');
ResolvedEdit {
range: range.end..range.end,
new_text,
description: Some(description),
}
}
Self::Delete { old_text } => {
let range = Self::resolve_location(&snapshot, &old_text);
ResolvedEdit {
range,
new_text: String::new(),
description: None,
}
}
}
}
fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
const INSERTION_COST: u32 = 3;
const WHITESPACE_INSERTION_COST: u32 = 1;
const DELETION_COST: u32 = 3;
const WHITESPACE_DELETION_COST: u32 = 1;
const EQUALITY_BONUS: u32 = 5;
struct Matrix {
cols: usize,
data: Vec<SearchState>,
}
impl Matrix {
fn new(rows: usize, cols: usize) -> Self {
Matrix {
cols,
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
}
}
fn get(&self, row: usize, col: usize) -> SearchState {
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
self.data[row * self.cols + col] = cost;
}
}
let buffer_len = buffer.len();
let query_len = search_query.len();
let mut matrix = Matrix::new(query_len + 1, buffer_len + 1);
for (row, query_byte) in search_query.bytes().enumerate() {
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
let deletion_cost = if query_byte.is_ascii_whitespace() {
WHITESPACE_DELETION_COST
} else {
DELETION_COST
};
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
WHITESPACE_INSERTION_COST
} else {
INSERTION_COST
};
let up = SearchState::new(
matrix.get(row, col + 1).score.saturating_sub(deletion_cost),
SearchDirection::Up,
);
let left = SearchState::new(
matrix
.get(row + 1, col)
.score
.saturating_sub(insertion_cost),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_byte == *buffer_byte {
matrix.get(row, col).score.saturating_add(EQUALITY_BONUS)
} else {
matrix
.get(row, col)
.score
.saturating_sub(deletion_cost + insertion_cost)
},
SearchDirection::Diagonal,
);
matrix.set(row + 1, col + 1, up.max(left).max(diagonal));
}
}
// Traceback to find the best match
let mut best_buffer_end = buffer_len;
let mut best_score = 0;
for col in 1..=buffer_len {
let score = matrix.get(query_len, col).score;
if score > best_score {
best_score = score;
best_buffer_end = col;
}
}
let mut query_ix = query_len;
let mut buffer_ix = best_buffer_end;
while query_ix > 0 && buffer_ix > 0 {
let current = matrix.get(query_ix, buffer_ix);
match current.direction {
SearchDirection::Diagonal => {
query_ix -= 1;
buffer_ix -= 1;
}
SearchDirection::Up => {
query_ix -= 1;
}
SearchDirection::Left => {
buffer_ix -= 1;
}
}
}
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
start.column = 0;
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
if end.column > 0 {
end.column = buffer.line_len(end.row);
}
buffer.anchor_after(start)..buffer.anchor_before(end)
}
}
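
`resolve_location` is a local-alignment dynamic program: each matrix cell scores the best alignment of a query prefix against the buffer up to that byte, whitespace edits are penalized less so the query survives reindentation, and a traceback from the best-scoring end column recovers where the match started. A compact sketch of the same scheme over plain byte slices, with uniform costs instead of the whitespace-aware ones above (illustrative, not the editor's implementation):

```rust
#[derive(Clone, Copy)]
enum Dir {
    Up,
    Left,
    Diag,
}

/// Fuzzy-locate `query` in `text` with a scoring matrix and traceback.
fn fuzzy_find(text: &[u8], query: &[u8]) -> std::ops::Range<usize> {
    const COST: u32 = 3; // uniform insertion/deletion cost
    const BONUS: u32 = 5; // reward for matching bytes

    let cols = text.len() + 1;
    let mut score = vec![0u32; (query.len() + 1) * cols];
    let mut dir = vec![Dir::Diag; (query.len() + 1) * cols];

    for r in 1..=query.len() {
        for c in 1..=text.len() {
            let diag = if query[r - 1] == text[c - 1] {
                score[(r - 1) * cols + (c - 1)] + BONUS
            } else {
                score[(r - 1) * cols + (c - 1)].saturating_sub(2 * COST)
            };
            let up = score[(r - 1) * cols + c].saturating_sub(COST);
            let left = score[r * cols + (c - 1)].saturating_sub(COST);
            let (best, d) = [(diag, Dir::Diag), (up, Dir::Up), (left, Dir::Left)]
                .into_iter()
                .max_by_key(|&(s, _)| s)
                .unwrap();
            score[r * cols + c] = best;
            dir[r * cols + c] = d;
        }
    }

    // The best match ends wherever the bottom row peaks...
    let end = (0..=text.len())
        .max_by_key(|&c| score[query.len() * cols + c])
        .unwrap();

    // ...and a traceback through the recorded directions finds its start.
    let (mut r, mut c) = (query.len(), end);
    while r > 0 && c > 0 {
        match dir[r * cols + c] {
            Dir::Diag => {
                r -= 1;
                c -= 1;
            }
            Dir::Up => r -= 1,
            Dir::Left => c -= 1,
        }
    }
    c..end
}

fn main() {
    let text = b"fn one() {}\nfn two() {}\n";
    let range = fuzzy_find(text, b"fn two");
    assert_eq!(&text[range], b"fn two");
}
```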
impl AssistantPatch {
pub(crate) async fn resolve(
&self,
project: Model<Project>,
cx: &mut AsyncAppContext,
) -> ResolvedPatch {
let mut resolve_tasks = Vec::new();
for (ix, edit) in self.edits.iter().enumerate() {
if let Ok(edit) = edit.as_ref() {
resolve_tasks.push(
edit.resolve(project.clone(), cx.clone())
.map_err(move |error| (ix, error)),
);
}
}
let edits = future::join_all(resolve_tasks).await;
let mut errors = Vec::new();
let mut edits_by_buffer = HashMap::default();
for entry in edits {
match entry {
Ok((buffer, edit)) => {
edits_by_buffer
.entry(buffer)
.or_insert_with(Vec::new)
.push(edit);
}
Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError {
edit_ix,
message: error.to_string(),
}),
}
}
// Expand the context ranges of each edit and group edits with overlapping context ranges.
let mut edit_groups_by_buffer = HashMap::default();
for (buffer, edits) in edits_by_buffer {
if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) {
edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot));
}
}
ResolvedPatch {
edit_groups: edit_groups_by_buffer,
errors,
}
}
fn group_edits(
mut edits: Vec<ResolvedEdit>,
snapshot: &text::BufferSnapshot,
) -> Vec<ResolvedEditGroup> {
let mut edit_groups = Vec::<ResolvedEditGroup>::new();
// Sort edits by their range so that earlier, larger ranges come first
edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot));
// Merge overlapping edits
edits.dedup_by(|a, b| b.try_merge(a, &snapshot));
// Create context ranges for each edit
for edit in edits {
let context_range = {
let edit_point_range = edit.range.to_point(&snapshot);
let start_row = edit_point_range.start.row.saturating_sub(5);
let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row);
let start = snapshot.anchor_before(Point::new(start_row, 0));
let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
start..end
};
if let Some(last_group) = edit_groups.last_mut() {
if last_group
.context_range
.end
.cmp(&context_range.start, &snapshot)
.is_ge()
{
// Merge with the previous group if context ranges overlap
last_group.context_range.end = context_range.end;
last_group.edits.push(edit);
} else {
// Create a new group
edit_groups.push(ResolvedEditGroup {
context_range,
edits: vec![edit],
});
}
} else {
// Create the first group
edit_groups.push(ResolvedEditGroup {
context_range,
edits: vec![edit],
});
}
}
edit_groups
}
pub fn path_count(&self) -> usize {
self.paths().count()
}
pub fn paths(&self) -> impl '_ + Iterator<Item = &str> {
let mut prev_path = None;
self.edits.iter().filter_map(move |edit| {
if let Ok(edit) = edit {
let path = Some(edit.path.as_str());
if path != prev_path {
prev_path = path;
return path;
}
}
None
})
}
}
impl PartialEq for AssistantPatch {
fn eq(&self, other: &Self) -> bool {
self.range == other.range
&& self.title == other.title
&& Arc::ptr_eq(&self.edits, &other.edits)
}
}
impl Eq for AssistantPatch {}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{AppContext, Context};
use language::{
language_settings::AllLanguageSettings, Language, LanguageConfig, LanguageMatcher,
};
use settings::SettingsStore;
use text::{OffsetRangeExt, Point};
use ui::BorrowAppContext;
use unindent::Unindent as _;
#[gpui::test]
fn test_resolve_location(cx: &mut AppContext) {
{
let buffer = cx.new_model(|cx| {
Buffer::local(
concat!(
" Lorem\n",
" ipsum\n",
" dolor sit amet\n",
" consecteur",
),
cx,
)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
AssistantEditKind::resolve_location(&snapshot, "ipsum\ndolor").to_point(&snapshot),
Point::new(1, 0)..Point::new(2, 18)
);
}
{
let buffer = cx.new_model(|cx| {
Buffer::local(
concat!(
"fn foo1(a: usize) -> usize {\n",
" 40\n",
"}\n",
"\n",
"fn foo2(b: usize) -> usize {\n",
" 42\n",
"}\n",
),
cx,
)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
AssistantEditKind::resolve_location(&snapshot, "fn foo1(b: usize) {\n40\n}")
.to_point(&snapshot),
Point::new(0, 0)..Point::new(2, 1)
);
}
{
let buffer = cx.new_model(|cx| {
Buffer::local(
concat!(
"fn main() {\n",
" Foo\n",
" .bar()\n",
" .baz()\n",
" .qux()\n",
"}\n",
"\n",
"fn foo2(b: usize) -> usize {\n",
" 42\n",
"}\n",
),
cx,
)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
AssistantEditKind::resolve_location(&snapshot, "Foo.bar.baz.qux()")
.to_point(&snapshot),
Point::new(1, 0)..Point::new(4, 14)
);
}
}
#[gpui::test]
fn test_resolve_edits(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
cx.update_global::<SettingsStore, _>(|settings, cx| {
settings.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
assert_edits(
"
/// A person
struct Person {
name: String,
age: usize,
}
/// A dog
struct Dog {
weight: f32,
}
impl Person {
fn name(&self) -> &str {
&self.name
}
}
"
.unindent(),
vec![
AssistantEditKind::Update {
old_text: "
name: String,
"
.unindent(),
new_text: "
first_name: String,
last_name: String,
"
.unindent(),
description: "".into(),
},
AssistantEditKind::Update {
old_text: "
fn name(&self) -> &str {
&self.name
}
"
.unindent(),
new_text: "
fn name(&self) -> String {
format!(\"{} {}\", self.first_name, self.last_name)
}
"
.unindent(),
description: "".into(),
},
],
"
/// A person
struct Person {
first_name: String,
last_name: String,
age: usize,
}
/// A dog
struct Dog {
weight: f32,
}
impl Person {
fn name(&self) -> String {
format!(\"{} {}\", self.first_name, self.last_name)
}
}
"
.unindent(),
cx,
);
// Ensure InsertBefore merges correctly with Update of the same text
assert_edits(
"
fn foo() {
}
"
.unindent(),
vec![
AssistantEditKind::InsertBefore {
old_text: "
fn foo() {"
.unindent(),
new_text: "
fn bar() {
qux();
}"
.unindent(),
description: "implement bar".into(),
},
AssistantEditKind::Update {
old_text: "
fn foo() {
}"
.unindent(),
new_text: "
fn foo() {
bar();
}"
.unindent(),
description: "call bar in foo".into(),
},
AssistantEditKind::InsertAfter {
old_text: "
fn foo() {
}
"
.unindent(),
new_text: "
fn qux() {
// todo
}
"
.unindent(),
description: "implement qux".into(),
},
],
"
fn bar() {
qux();
}
fn foo() {
bar();
}
fn qux() {
// todo
}
"
.unindent(),
cx,
);
}
#[track_caller]
fn assert_edits(
old_text: String,
edits: Vec<AssistantEditKind>,
new_text: String,
cx: &mut AppContext,
) {
let buffer =
cx.new_model(|cx| Buffer::local(old_text, cx).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
let resolved_edits = edits
.into_iter()
.map(|kind| kind.resolve(&snapshot))
.collect();
let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot);
ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx);
let actual_new_text = buffer.read(cx).text();
pretty_assertions::assert_eq!(actual_new_text, new_text);
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(language::tree_sitter_rust::LANGUAGE.into()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}

View File

@@ -45,6 +45,15 @@ pub struct ProjectSlashCommandPromptContext {
pub context_buffer: String,
}
/// Context required to generate a workflow step resolution prompt.
#[derive(Debug, Serialize)]
pub struct StepResolutionContext {
/// The full context, including <step>...</step> tags
pub workflow_context: String,
/// The text of the specific step from the context to resolve
pub step_to_resolve: String,
}
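
Since the struct derives `Serialize`, it can be handed straight to the prompt templating layer. A hypothetical construction showing the shape the template receives (the field values here are invented):

```rust
use serde::Serialize;

#[derive(Debug, Serialize)]
pub struct StepResolutionContext {
    pub workflow_context: String,
    pub step_to_resolve: String,
}

fn main() {
    let context = StepResolutionContext {
        // The whole conversation, <step>...</step> markup included.
        workflow_context: "one\ntwo\n<step>Add a second function</step>".into(),
        // Only the text of the step being resolved.
        step_to_resolve: "Add a second function".into(),
    };
    // The template engine sees the struct as two named string fields.
    println!("{}", serde_json::to_string_pretty(&context).unwrap());
}
```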
pub struct PromptLoadingParams<'a> {
pub fs: Arc<dyn Fs>,
pub repo_path: Option<PathBuf>,

View File

@@ -1,4 +1,3 @@
use super::create_label_for_command;
use anyhow::{anyhow, Result};
use assistant_slash_command::{
AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
@@ -7,9 +6,9 @@ use assistant_slash_command::{
use collections::HashMap;
use context_servers::{
manager::{ContextServer, ContextServerManager},
types::Prompt,
protocol::PromptInfo,
};
use gpui::{AppContext, Task, WeakView, WindowContext};
use gpui::{Task, WeakView, WindowContext};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
@@ -19,11 +18,11 @@ use workspace::Workspace;
pub struct ContextServerSlashCommand {
server_id: String,
prompt: Prompt,
prompt: PromptInfo,
}
impl ContextServerSlashCommand {
pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
Self {
server_id: server.id.clone(),
prompt,
@@ -36,28 +35,12 @@ impl SlashCommand for ContextServerSlashCommand {
self.prompt.name.clone()
}
fn label(&self, cx: &AppContext) -> language::CodeLabel {
let mut parts = vec![self.prompt.name.as_str()];
if let Some(args) = &self.prompt.arguments {
if let Some(arg) = args.first() {
parts.push(arg.name.as_str());
}
}
create_label_for_command(&parts[0], &parts[1..], cx)
}
fn description(&self) -> String {
match &self.prompt.description {
Some(desc) => desc.clone(),
None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
}
format!("Run context server command: {}", self.prompt.name)
}
fn menu_text(&self) -> String {
match &self.prompt.description {
Some(desc) => desc.clone(),
None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
}
format!("Run '{}' from {}", self.prompt.name, self.server_id)
}
fn requires_argument(&self) -> bool {
@@ -171,7 +154,7 @@ impl SlashCommand for ContextServerSlashCommand {
}
}
fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
if arguments.is_empty() {
return Err(anyhow!("No arguments given"));
}
@@ -187,7 +170,7 @@ fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String,
}
}
fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
match &prompt.arguments {
Some(args) if args.len() > 1 => Err(anyhow!(
"Prompt has more than one argument, which is not supported"
@@ -216,7 +199,7 @@ fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<Str
/// MCP servers can return prompts with multiple arguments. Since we only
/// support one argument, we ignore all others; this predicate checks whether
/// a prompt is usable under that constraint.
pub fn acceptable_prompt(prompt: &Prompt) -> bool {
pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
match &prompt.arguments {
None => true,
Some(args) if args.len() <= 1 => true,

View File

@@ -18,8 +18,6 @@ pub(crate) struct WorkflowSlashCommand {
}
impl WorkflowSlashCommand {
pub const NAME: &'static str = "workflow";
pub fn new(prompt_builder: Arc<PromptBuilder>) -> Self {
Self { prompt_builder }
}
@@ -27,7 +25,7 @@ impl WorkflowSlashCommand {
impl SlashCommand for WorkflowSlashCommand {
fn name(&self) -> String {
Self::NAME.into()
"workflow".into()
}
fn description(&self) -> String {

View File

@@ -38,10 +38,7 @@ impl Settings for SlashCommandSettings {
fn load(sources: SettingsSources<Self::FileContent>, _cx: &mut AppContext) -> Result<Self> {
SettingsSources::<Self::FileContent>::json_merge_with(
[sources.default]
.into_iter()
.chain(sources.user)
.chain(sources.server),
[sources.default].into_iter().chain(sources.user),
)
}
}

View File

@@ -414,7 +414,7 @@ impl TerminalInlineAssist {
struct InlineAssistantError;
let id =
NotificationId::composite::<InlineAssistantError>(
NotificationId::identified::<InlineAssistantError>(
assist_id.0,
);

View File

@@ -0,0 +1,507 @@
use crate::{AssistantPanel, InlineAssistId, InlineAssistant};
use anyhow::{anyhow, Context as _, Result};
use collections::HashMap;
use editor::Editor;
use gpui::AsyncAppContext;
use gpui::{Model, Task, UpdateGlobal as _, View, WeakView, WindowContext};
use language::{Buffer, BufferSnapshot};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{ops::Range, path::Path, sync::Arc};
use text::Bias;
use workspace::Workspace;
#[derive(Debug)]
pub(crate) struct WorkflowStep {
pub range: Range<language::Anchor>,
pub leading_tags_end: text::Anchor,
pub trailing_tag_start: Option<text::Anchor>,
pub edits: Arc<[Result<WorkflowStepEdit>]>,
pub resolution_task: Option<Task<()>>,
pub resolution: Option<Arc<Result<WorkflowStepResolution>>>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct WorkflowStepEdit {
pub path: String,
pub kind: WorkflowStepEditKind,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct WorkflowStepResolution {
pub title: String,
pub suggestion_groups: HashMap<Model<Buffer>, Vec<WorkflowSuggestionGroup>>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WorkflowSuggestionGroup {
pub context_range: Range<language::Anchor>,
pub suggestions: Vec<WorkflowSuggestion>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WorkflowSuggestion {
Update {
range: Range<language::Anchor>,
description: String,
},
CreateFile {
description: String,
},
InsertBefore {
position: language::Anchor,
description: String,
},
InsertAfter {
position: language::Anchor,
description: String,
},
Delete {
range: Range<language::Anchor>,
},
}
impl WorkflowSuggestion {
pub fn range(&self) -> Range<language::Anchor> {
match self {
Self::Update { range, .. } => range.clone(),
Self::CreateFile { .. } => language::Anchor::MIN..language::Anchor::MAX,
Self::InsertBefore { position, .. } | Self::InsertAfter { position, .. } => {
*position..*position
}
Self::Delete { range, .. } => range.clone(),
}
}
pub fn description(&self) -> Option<&str> {
match self {
Self::Update { description, .. }
| Self::CreateFile { description }
| Self::InsertBefore { description, .. }
| Self::InsertAfter { description, .. } => Some(description),
Self::Delete { .. } => None,
}
}
fn description_mut(&mut self) -> Option<&mut String> {
match self {
Self::Update { description, .. }
| Self::CreateFile { description }
| Self::InsertBefore { description, .. }
| Self::InsertAfter { description, .. } => Some(description),
Self::Delete { .. } => None,
}
}
pub fn try_merge(&mut self, other: &Self, buffer: &BufferSnapshot) -> bool {
let range = self.range();
let other_range = other.range();
// Don't merge if we don't contain the other suggestion.
if range.start.cmp(&other_range.start, buffer).is_gt()
|| range.end.cmp(&other_range.end, buffer).is_lt()
{
return false;
}
if let Some(description) = self.description_mut() {
if let Some(other_description) = other.description() {
description.push('\n');
description.push_str(other_description);
}
}
true
}
pub fn show(
&self,
editor: &View<Editor>,
excerpt_id: editor::ExcerptId,
workspace: &WeakView<Workspace>,
assistant_panel: &View<AssistantPanel>,
cx: &mut WindowContext,
) -> Option<InlineAssistId> {
let mut initial_transaction_id = None;
let initial_prompt;
let suggestion_range;
let buffer = editor.read(cx).buffer().clone();
let snapshot = buffer.read(cx).snapshot(cx);
match self {
Self::Update {
range, description, ..
} => {
initial_prompt = description.clone();
suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
}
Self::CreateFile { description } => {
initial_prompt = description.clone();
suggestion_range = editor::Anchor::min()..editor::Anchor::min();
}
Self::InsertBefore {
position,
description,
..
} => {
let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
initial_prompt = description.clone();
suggestion_range = buffer.update(cx, |buffer, cx| {
buffer.start_transaction(cx);
let line_start = buffer.insert_empty_line(position, true, true, cx);
initial_transaction_id = buffer.end_transaction(cx);
buffer.refresh_preview(cx);
let line_start = buffer.read(cx).anchor_before(line_start);
line_start..line_start
});
}
Self::InsertAfter {
position,
description,
..
} => {
let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
initial_prompt = description.clone();
suggestion_range = buffer.update(cx, |buffer, cx| {
buffer.start_transaction(cx);
let line_start = buffer.insert_empty_line(position, true, true, cx);
initial_transaction_id = buffer.end_transaction(cx);
buffer.refresh_preview(cx);
let line_start = buffer.read(cx).anchor_before(line_start);
line_start..line_start
});
}
Self::Delete { range, .. } => {
initial_prompt = "Delete".to_string();
suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
}
}
InlineAssistant::update_global(cx, |inline_assistant, cx| {
Some(inline_assistant.suggest_assist(
editor,
suggestion_range,
initial_prompt,
initial_transaction_id,
false,
Some(workspace.clone()),
Some(assistant_panel),
cx,
))
})
}
}
impl WorkflowStepEdit {
pub fn new(
path: Option<String>,
operation: Option<String>,
search: Option<String>,
description: Option<String>,
) -> Result<Self> {
let path = path.ok_or_else(|| anyhow!("missing path"))?;
let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
let kind = match operation.as_str() {
"update" => WorkflowStepEditKind::Update {
search: search.ok_or_else(|| anyhow!("missing search"))?,
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
"insert_before" => WorkflowStepEditKind::InsertBefore {
search: search.ok_or_else(|| anyhow!("missing search"))?,
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
"insert_after" => WorkflowStepEditKind::InsertAfter {
search: search.ok_or_else(|| anyhow!("missing search"))?,
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
"delete" => WorkflowStepEditKind::Delete {
search: search.ok_or_else(|| anyhow!("missing search"))?,
},
"create" => WorkflowStepEditKind::Create {
description: description.ok_or_else(|| anyhow!("missing description"))?,
},
_ => Err(anyhow!("unknown operation {operation:?}"))?,
};
Ok(Self { path, kind })
}
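/// Opens (or creates) the buffer for this edit's path and resolves the edit
/// into a concrete `WorkflowSuggestion` against a snapshot of that buffer.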
pub async fn resolve(
&self,
project: Model<Project>,
mut cx: AsyncAppContext,
) -> Result<(Model<Buffer>, super::WorkflowSuggestion)> {
let path = self.path.clone();
let kind = self.kind.clone();
let buffer = project
.update(&mut cx, |project, cx| {
let project_path = project
.find_project_path(Path::new(&path), cx)
.or_else(|| {
// If we couldn't find a project path for it, put it in the active worktree
// so that when we create the buffer, it can be saved.
let worktree = project
.active_entry()
.and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
.or_else(|| project.worktrees(cx).next())?;
let worktree = worktree.read(cx);
Some(ProjectPath {
worktree_id: worktree.id(),
path: Arc::from(Path::new(&path)),
})
})
.with_context(|| format!("worktree not found for {:?}", path))?;
anyhow::Ok(project.open_buffer(project_path, cx))
})??
.await?;
let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
let suggestion = cx
.background_executor()
.spawn(async move {
match kind {
WorkflowStepEditKind::Update {
search,
description,
} => {
let range = Self::resolve_location(&snapshot, &search);
WorkflowSuggestion::Update { range, description }
}
WorkflowStepEditKind::Create { description } => {
WorkflowSuggestion::CreateFile { description }
}
WorkflowStepEditKind::InsertBefore {
search,
description,
} => {
let range = Self::resolve_location(&snapshot, &search);
WorkflowSuggestion::InsertBefore {
position: range.start,
description,
}
}
WorkflowStepEditKind::InsertAfter {
search,
description,
} => {
let range = Self::resolve_location(&snapshot, &search);
WorkflowSuggestion::InsertAfter {
position: range.end,
description,
}
}
WorkflowStepEditKind::Delete { search } => {
let range = Self::resolve_location(&snapshot, &search);
WorkflowSuggestion::Delete { range }
}
}
})
.await;
Ok((buffer, suggestion))
}
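/// Finds the range in `buffer` that best matches `search_query`, even when the
/// query doesn't occur verbatim, using a Smith-Waterman-style local alignment
/// over bytes. The best-scoring match is then expanded to whole lines.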
fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
const INSERTION_SCORE: f64 = -1.0;
const DELETION_SCORE: f64 = -1.0;
const REPLACEMENT_SCORE: f64 = -1.0;
const EQUALITY_SCORE: f64 = 5.0;
struct Matrix {
cols: usize,
data: Vec<f64>,
}
impl Matrix {
fn new(rows: usize, cols: usize) -> Self {
Matrix {
cols,
data: vec![0.0; rows * cols],
}
}
fn get(&self, row: usize, col: usize) -> f64 {
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, value: f64) {
self.data[row * self.cols + col] = value;
}
}
let buffer_len = buffer.len();
let query_len = search_query.len();
let mut matrix = Matrix::new(query_len + 1, buffer_len + 1);
for (i, query_byte) in search_query.bytes().enumerate() {
for (j, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
let match_score = if query_byte == *buffer_byte {
EQUALITY_SCORE
} else {
REPLACEMENT_SCORE
};
let up = matrix.get(i + 1, j) + DELETION_SCORE;
let left = matrix.get(i, j + 1) + INSERTION_SCORE;
let diagonal = matrix.get(i, j) + match_score;
let score = up.max(left.max(diagonal)).max(0.);
matrix.set(i + 1, j + 1, score);
}
}
// Traceback to find the best match
let mut best_buffer_end = buffer_len;
let mut best_score = 0.0;
for col in 1..=buffer_len {
let score = matrix.get(query_len, col);
if score > best_score {
best_score = score;
best_buffer_end = col;
}
}
let mut query_ix = query_len;
let mut buffer_ix = best_buffer_end;
while query_ix > 0 && buffer_ix > 0 {
let current = matrix.get(query_ix, buffer_ix);
let up = matrix.get(query_ix - 1, buffer_ix);
let left = matrix.get(query_ix, buffer_ix - 1);
if current == left + INSERTION_SCORE {
buffer_ix -= 1;
} else if current == up + DELETION_SCORE {
query_ix -= 1;
} else {
query_ix -= 1;
buffer_ix -= 1;
}
}
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
start.column = 0;
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
end.column = buffer.line_len(end.row);
buffer.anchor_after(start)..buffer.anchor_before(end)
}
}
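For reference, the alignment above can be read in isolation: fill a `(query + 1) × (buffer + 1)` score matrix with match/mismatch and gap penalties, clamp scores at zero (which makes the alignment local), pick the best-scoring column in the last row as the match's end, and trace back to recover its start. Below is a minimal, standalone sketch over plain byte strings; `fuzzy_find` and its signature are illustrative, not part of the codebase, and it returns a raw byte range rather than snapping to whole lines as `resolve_location` does:

```
use std::ops::Range;

// Same scoring as above; insertion and deletion cost the same, so the
// traceback's two gap branches are numerically interchangeable.
const GAP: f64 = -1.0;
const REPLACEMENT: f64 = -1.0;
const EQUALITY: f64 = 5.0;

fn fuzzy_find(haystack: &str, needle: &str) -> Range<usize> {
    let cols = haystack.len() + 1;
    // Row-major (needle.len() + 1) x cols score matrix, zero-initialized.
    let mut m = vec![0.0_f64; (needle.len() + 1) * cols];
    let at = |i: usize, j: usize| i * cols + j;

    for (i, q) in needle.bytes().enumerate() {
        for (j, b) in haystack.bytes().enumerate() {
            let diagonal = m[at(i, j)] + if q == b { EQUALITY } else { REPLACEMENT };
            let consume_haystack = m[at(i + 1, j)] + GAP; // gap in the needle
            let skip_needle = m[at(i, j + 1)] + GAP; // gap in the haystack
            m[at(i + 1, j + 1)] = diagonal.max(consume_haystack).max(skip_needle).max(0.0);
        }
    }

    // The best local match ends at the highest-scoring column of the last row.
    let last = needle.len();
    let (mut end, mut best) = (haystack.len(), 0.0);
    for j in 1..cols {
        if m[at(last, j)] > best {
            best = m[at(last, j)];
            end = j;
        }
    }

    // Trace back through the matrix to find where that match starts.
    let (mut i, mut j) = (last, end);
    while i > 0 && j > 0 {
        let cur = m[at(i, j)];
        if cur == m[at(i, j - 1)] + GAP {
            j -= 1; // came from a haystack-only step
        } else if cur == m[at(i - 1, j)] + GAP {
            i -= 1; // came from a needle-only step
        } else {
            i -= 1; // diagonal: match or replacement
            j -= 1;
        }
    }
    j..end
}
```

On the second test case below, `fuzzy_find` recovers roughly the byte span of the first function despite the mismatched parameter name; the real code then snaps that span to whole lines using the buffer's line lengths.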
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "operation")]
pub enum WorkflowStepEditKind {
/// Rewrites the specified text entirely based on the given description.
/// This operation completely replaces the given text.
Update {
/// A string in the source text to apply the update to.
search: String,
/// A brief description of the transformation to apply to the text.
description: String,
},
/// Creates a new file with the given path based on the provided description.
/// This operation adds a new file to the codebase.
Create {
/// A brief description of the file to be created.
description: String,
},
/// Inserts text before the specified text in the source file.
InsertBefore {
/// A string in the source text to insert text before.
search: String,
/// A brief description of how the new text should be generated.
description: String,
},
/// Inserts text after the specified text in the source file.
InsertAfter {
/// A string in the source text to insert text after.
search: String,
/// A brief description of how the new text should be generated.
description: String,
},
/// Deletes the specified text from the containing file.
Delete {
/// A string in the source text to delete.
search: String,
},
}
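As a quick illustration of the wire shape `#[serde(tag = "operation")]` implies (a sketch assuming the enum above is in scope; note serde uses the variant name verbatim as the tag value unless a `rename_all` attribute, not shown in this listing, lowercases it — the snake_case strings accepted by `WorkflowStepEdit::new` are parsed separately):

```
fn main() {
    let edit = WorkflowStepEditKind::InsertAfter {
        search: "fn main() {".into(),
        description: "Add a logging statement".into(),
    };
    // Internally tagged: the variant name becomes the "operation" field.
    println!("{}", serde_json::to_string_pretty(&edit).unwrap());
    // {
    //   "operation": "InsertAfter",
    //   "search": "fn main() {",
    //   "description": "Add a logging statement"
    // }
}
```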
#[cfg(test)]
mod tests {
use super::*;
use gpui::{AppContext, Context};
use text::{OffsetRangeExt, Point};
#[gpui::test]
fn test_resolve_location(cx: &mut AppContext) {
{
let buffer = cx.new_model(|cx| {
Buffer::local(
concat!(
" Lorem\n",
" ipsum\n",
" dolor sit amet\n",
" consecteur",
),
cx,
)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
WorkflowStepEdit::resolve_location(&snapshot, "ipsum\ndolor").to_point(&snapshot),
Point::new(1, 0)..Point::new(2, 18)
);
}
{
let buffer = cx.new_model(|cx| {
Buffer::local(
concat!(
"fn foo1(a: usize) -> usize {\n",
" 42\n",
"}\n",
"\n",
"fn foo2(b: usize) -> usize {\n",
" 42\n",
"}\n",
),
cx,
)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
WorkflowStepEdit::resolve_location(&snapshot, "fn foo1(b: usize) {\n42\n}")
.to_point(&snapshot),
Point::new(0, 0)..Point::new(2, 1)
);
}
{
let buffer = cx.new_model(|cx| {
Buffer::local(
concat!(
"fn main() {\n",
" Foo\n",
" .bar()\n",
" .baz()\n",
" .qux()\n",
"}\n",
"\n",
"fn foo2(b: usize) -> usize {\n",
" 42\n",
"}\n",
),
cx,
)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
WorkflowStepEdit::resolve_location(&snapshot, "Foo.bar.baz.qux()")
.to_point(&snapshot),
Point::new(1, 0)..Point::new(4, 14)
);
}
}
}

View File

@@ -130,7 +130,7 @@ impl Settings for AutoUpdateSetting {
type FileContent = Option<AutoUpdateSettingContent>;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
let auto_update = [sources.server, sources.release_channel, sources.user]
let auto_update = [sources.release_channel, sources.user]
.into_iter()
.find_map(|value| value.copied().flatten())
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
@@ -464,7 +464,6 @@ impl AutoUpdater {
smol::fs::create_dir_all(&platform_dir).await.ok();
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
if smol::fs::metadata(&version_path).await.is_err() {
log::info!("downloading zed-remote-server {os} {arch}");
download_remote_server_binary(&version_path, release, client, cx).await?;

View File

@@ -34,8 +34,8 @@ postage.workspace = true
rand.workspace = true
release_channel.workspace = true
rpc = { workspace = true, features = ["gpui"] }
rustls-native-certs.workspace = true
rustls.workspace = true
rustls-native-certs.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -4,7 +4,6 @@ pub mod test;
mod socks;
pub mod telemetry;
pub mod user;
pub mod zed_urls;
use anyhow::{anyhow, bail, Context as _, Result};
use async_recursion::async_recursion;
@@ -142,7 +141,6 @@ impl Settings for ProxySettings {
Ok(Self {
proxy: sources
.user
.or(sources.server)
.and_then(|value| value.proxy.clone())
.or(sources.default.proxy.clone()),
})
@@ -474,21 +472,15 @@ impl settings::Settings for TelemetrySettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
Ok(Self {
diagnostics: sources
.user
.as_ref()
.or(sources.server.as_ref())
.and_then(|v| v.diagnostics)
.unwrap_or(
sources
.default
.diagnostics
.ok_or_else(Self::missing_default)?,
),
diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
sources
.default
.diagnostics
.ok_or_else(Self::missing_default)?,
),
metrics: sources
.user
.as_ref()
.or(sources.server.as_ref())
.and_then(|v| v.metrics)
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
})
@@ -1031,7 +1023,7 @@ impl Client {
&self,
http: Arc<HttpClientWithUrl>,
release_channel: Option<ReleaseChannel>,
) -> impl Future<Output = Result<url::Url>> {
) -> impl Future<Output = Result<Url>> {
#[cfg(any(test, feature = "test-support"))]
let url_override = self.rpc_url.read().clone();
@@ -1125,7 +1117,7 @@ impl Client {
// for us from the RPC URL.
//
// Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
let mut request = rpc_url.into_client_request()?;
// We then modify the request to add our desired headers.
let request_headers = request.headers_mut();
@@ -1164,7 +1156,6 @@ impl Client {
.with_root_certificates(root_store)
.with_no_client_auth()
};
let (stream, _) =
async_tungstenite::async_tls::client_async_tls_with_connector(
request,

View File

@@ -1,19 +0,0 @@
//! Contains helper functions for constructing URLs to various Zed-related pages.
//!
//! These URLs will adapt to the configured server URL in order to construct
//! links appropriate for the environment (e.g., by linking to a local copy of
//! zed.dev in development).
use gpui::AppContext;
use settings::Settings;
use crate::ClientSettings;
fn server_url(cx: &AppContext) -> &str {
&ClientSettings::get_global(cx).server_url
}
/// Returns the URL to the account page on zed.dev.
pub fn account_url(cx: &AppContext) -> String {
format!("{server_url}/account", server_url = server_url(cx))
}

View File

@@ -38,6 +38,7 @@ futures.workspace = true
google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
isahc_http_client.workspace = true
jsonwebtoken.workspace = true
live_kit_server.workspace = true
log.workspace = true
@@ -48,7 +49,6 @@ prometheus = "0.13"
prost.workspace = true
rand.workspace = true
reqwest = { version = "0.11", features = ["json"] }
reqwest_client.workspace = true
rpc.workspace = true
rustc-demangle.workspace = true
scrypt = "0.11"
@@ -67,7 +67,7 @@ telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["full"] }
tokio.workspace = true
toml.workspace = true
tower = "0.4"
tower-http = { workspace = true, features = ["trace"] }

View File

@@ -199,12 +199,6 @@ spec:
secretKeyRef:
name: slack
key: panics_webhook
- name: STRIPE_API_KEY
valueFrom:
secretKeyRef:
name: stripe
key: api_key
optional: true
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
value: "1000"
- name: SUPERMAVEN_ADMIN_API_KEY

View File

@@ -78,10 +78,10 @@ CREATE TABLE "worktree_entries" (
"id" INTEGER NOT NULL,
"is_dir" BOOL NOT NULL,
"path" VARCHAR NOT NULL,
"canonical_path" TEXT,
"inode" INTEGER NOT NULL,
"mtime_seconds" INTEGER NOT NULL,
"mtime_nanos" INTEGER NOT NULL,
"is_symlink" BOOL NOT NULL,
"is_external" BOOL NOT NULL,
"is_ignored" BOOL NOT NULL,
"is_deleted" BOOL NOT NULL,

View File

@@ -1,2 +0,0 @@
ALTER TABLE worktree_entries ADD COLUMN canonical_path text;
ALTER TABLE worktree_entries ALTER COLUMN is_symlink SET DEFAULT false;

View File

@@ -1,12 +0,0 @@
create table billing_events (
id serial primary key,
idempotency_key uuid not null default gen_random_uuid(),
user_id integer not null,
model_id integer not null references models (id) on delete cascade,
input_tokens bigint not null default 0,
input_cache_creation_tokens bigint not null default 0,
input_cache_read_tokens bigint not null default 0,
output_tokens bigint not null default 0
);
create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);

View File

@@ -1,3 +1,7 @@
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, bail, Context};
use axum::{
extract::{self, Query},
@@ -5,35 +9,28 @@ use axum::{
Extension, Json, Router,
};
use chrono::{DateTime, SecondsFormat, Utc};
use collections::HashSet;
use reqwest::StatusCode;
use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
use std::{str::FromStr, sync::Arc, time::Duration};
use stripe::{
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
CreateBillingPortalSessionFlowDataAfterCompletion,
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
Subscription, SubscriptionId, SubscriptionStatus,
};
use util::ResultExt;
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
};
use crate::llm::db::LlmDatabase;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::rpc::{ResultExt as _, Server};
use crate::{
db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
UpdateBillingSubscriptionParams,
},
stripe_billing::StripeBilling,
};
use crate::{
db::{billing_subscription::StripeSubscriptionStatus, UserId},
llm::db::LlmDatabase,
};
use crate::rpc::ResultExt as _;
use crate::{AppState, Error, Result};
pub fn router() -> Router {
@@ -50,7 +47,6 @@ pub fn router() -> Router {
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
.route("/billing/monthly_spend", get(get_monthly_spend))
}
#[derive(Debug, Deserialize)]
@@ -91,7 +87,6 @@ struct UpdateBillingPreferencesBody {
async fn update_billing_preferences(
Extension(app): Extension<Arc<AppState>>,
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
) -> Result<Json<BillingPreferencesResponse>> {
let user = app
@@ -124,8 +119,6 @@ async fn update_billing_preferences(
.await?
};
rpc_server.refresh_llm_tokens_for_user(user.id).await;
Ok(Json(BillingPreferencesResponse {
max_monthly_llm_usage_spending_in_cents: billing_preferences
.max_monthly_llm_usage_spending_in_cents,
@@ -204,22 +197,12 @@ async fn create_billing_subscription(
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database");
let Some((stripe_client, stripe_access_price_id)) = app
.stripe_client
.clone()
.zip(app.config.stripe_llm_access_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
@@ -243,14 +226,26 @@ async fn create_billing_subscription(
customer.id
};
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
let checkout_session_url = stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?;
let checkout_session = {
let mut params = CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_access_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
params.success_url = Some(&success_url);
CheckoutSession::create(&stripe_client, params).await?
};
Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url,
checkout_session_url: checkout_session
.url
.ok_or_else(|| anyhow!("no checkout session URL"))?,
}))
}
@@ -405,7 +400,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
@@ -416,9 +411,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
let executor = executor.clone();
async move {
loop {
poll_stripe_events(&app, &rpc_server, &stripe_client)
.await
.log_err();
poll_stripe_events(&app, &stripe_client).await.log_err();
executor.sleep(POLL_EVENTS_INTERVAL).await;
}
@@ -428,7 +421,6 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
async fn poll_stripe_events(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
fn event_type_to_string(event_type: EventType) -> String {
@@ -453,28 +445,29 @@ async fn poll_stripe_events(
let mut pages_of_already_processed_events = 0;
let mut unprocessed_events = Vec::new();
log::info!(
"Stripe events: starting retrieval for {}",
event_types.join(", ")
);
let mut params = ListEvents::new();
params.types = Some(event_types.clone());
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
let mut event_pages = stripe::Event::list(&stripe_client, &params)
.await?
.paginate(params);
loop {
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP {
log::info!("saw {pages_of_already_processed_events} pages of already-processed events: stopping event retrieval");
break;
}
log::info!("retrieving events from Stripe: {}", event_types.join(", "));
let mut params = ListEvents::new();
params.types = Some(event_types.clone());
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
let events = stripe::Event::list(stripe_client, &params).await?;
let processed_event_ids = {
let event_ids = event_pages
.page
let event_ids = &events
.data
.iter()
.map(|event| event.id.as_str())
.collect::<Vec<_>>();
app.db
.get_processed_stripe_events_by_event_ids(&event_ids)
.get_processed_stripe_events_by_event_ids(event_ids)
.await?
.into_iter()
.map(|event| event.stripe_event_id)
@@ -482,13 +475,13 @@ async fn poll_stripe_events(
};
let mut processed_events_in_page = 0;
let events_in_page = event_pages.page.data.len();
for event in &event_pages.page.data {
let events_in_page = events.data.len();
for event in events.data {
if processed_event_ids.contains(&event.id.to_string()) {
processed_events_in_page += 1;
log::debug!("Stripe events: already processed '{}', skipping", event.id);
log::debug!("Stripe event {} already processed: skipping", event.id);
} else {
unprocessed_events.push(event.clone());
unprocessed_events.push(event);
}
}
@@ -496,21 +489,15 @@ async fn poll_stripe_events(
pages_of_already_processed_events += 1;
}
if event_pages.page.has_more {
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
{
log::info!("Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events");
break;
} else {
log::info!("Stripe events: retrieving next page");
event_pages = event_pages.next(&stripe_client).await?;
}
} else {
if !events.has_more {
break;
}
}
log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
log::info!(
"unprocessed events from Stripe: {}",
unprocessed_events.len()
);
// Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
@@ -526,12 +513,12 @@ async fn poll_stripe_events(
// If the event has happened too far in the past, we don't want to
// process it and risk overwriting other more-recent updates.
//
// 1 day was chosen arbitrarily. This could be made longer or shorter.
let one_day = Duration::from_secs(24 * 60 * 60);
let a_day_ago = Utc::now() - one_day;
if a_day_ago.timestamp() > event.created {
// 1 hour was chosen arbitrarily. This could be made longer or shorter.
let one_hour = Duration::from_secs(60 * 60);
let an_hour_ago = Utc::now() - one_hour;
if an_hour_ago.timestamp() > event.created {
log::info!(
"Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
"Stripe event {} is more than {one_hour:?} old, marking as processed",
event_id
);
app.db
@@ -550,7 +537,7 @@ async fn poll_stripe_events(
| EventType::CustomerSubscriptionPaused
| EventType::CustomerSubscriptionResumed
| EventType::CustomerSubscriptionDeleted => {
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
handle_customer_subscription_event(app, stripe_client, event).await
}
_ => Ok(()),
};
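The retrieval loop in `poll_stripe_events` embodies a simple stopping rule: keep paging through events newest-first, count each page that contained nothing new, and stop once enough such pages have accumulated — at that point we've caught up with what's already recorded. A minimal sketch, with a hypothetical page fetcher standing in for the Stripe events API (`Page`, `fetch_next`, and `is_processed` are illustrative names, not the Stripe SDK's):

```
struct Page {
    ids: Vec<String>,
    has_more: bool,
}

fn collect_unprocessed(
    mut fetch_next: impl FnMut() -> Page,
    is_processed: impl Fn(&str) -> bool,
    max_processed_pages: usize,
) -> Vec<String> {
    let mut unprocessed = Vec::new();
    let mut processed_pages = 0;
    loop {
        if processed_pages >= max_processed_pages {
            // Enough pages contained nothing new: assume we've caught up
            // with what's already recorded in the database.
            break;
        }
        let page = fetch_next();
        let fresh: Vec<String> = page
            .ids
            .into_iter()
            .filter(|id| !is_processed(id))
            .collect();
        if fresh.is_empty() {
            processed_pages += 1;
        }
        unprocessed.extend(fresh);
        if !page.has_more {
            break;
        }
    }
    unprocessed
}
```

As in the real code, the collected events would then be sorted ascending by creation time so they're handled in the order they occurred.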
@@ -618,7 +605,6 @@ async fn handle_customer_event(
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
@@ -664,52 +650,9 @@ async fn handle_customer_subscription_event(
.await?;
}
// When the user's subscription changes, we want to refresh their LLM tokens
// to either grant/revoke access.
rpc_server
.refresh_llm_tokens_for_user(billing_customer.user_id)
.await;
Ok(())
}
#[derive(Debug, Deserialize)]
struct GetMonthlySpendParams {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct GetMonthlySpendResponse {
monthly_spend_in_cents: i32,
}
async fn get_monthly_spend(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetMonthlySpendParams>,
) -> Result<Json<GetMonthlySpendResponse>> {
let user = app
.db
.get_user_by_github_user_id(params.github_user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some(llm_db) = app.llm_db.clone() else {
return Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"LLM database not available".into(),
));
};
let monthly_spend = llm_db
.get_user_spending_for_month(user.id, Utc::now())
.await?
.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
Ok(Json(GetMonthlySpendResponse {
monthly_spend_in_cents: monthly_spend.0 as i32,
}))
}
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
fn from(value: SubscriptionStatus) -> Self {
match value {
@@ -772,15 +715,15 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::warn!("failed to retrieve Stripe billing object");
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve LLM database");
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
log::warn!("failed to retrieve Stripe LLM usage price ID");
return;
};
@@ -789,10 +732,15 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let executor = executor.clone();
async move {
loop {
sync_with_stripe(&app, &llm_db, &stripe_billing)
.await
.context("failed to sync LLM usage to Stripe")
.trace_err();
sync_with_stripe(
&app,
&llm_db,
&stripe_client,
stripe_llm_usage_price_id.clone(),
)
.await
.trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
}
}
@@ -801,44 +749,71 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
async fn sync_with_stripe(
app: &Arc<AppState>,
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: Arc<str>,
) -> anyhow::Result<()> {
let events = llm_db.get_billing_events().await?;
let user_ids = events
.iter()
.map(|(event, _)| event.user_id)
.collect::<HashSet<UserId>>();
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
let subscriptions = app.db.get_active_billing_subscriptions().await?;
for (event, model) in events {
let Some((stripe_db_customer, stripe_db_subscription)) =
stripe_subscriptions.get(&event.user_id)
else {
tracing::warn!(
user_id = event.user_id.0,
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
);
continue;
};
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
.stripe_subscription_id
.parse()
.context("failed to parse stripe subscription id from db")?;
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
.stripe_customer_id
.parse()
.context("failed to parse stripe customer id from db")?;
let stripe_model = stripe_billing.register_model(&model).await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
stripe_billing
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
.await?;
llm_db.consume_billing_event(event.id).await?;
for (customer, subscription) in subscriptions {
update_stripe_subscription(
llm_db,
stripe_client,
&stripe_llm_usage_price_id,
customer,
subscription,
)
.await
.log_err();
}
Ok(())
}
async fn update_stripe_subscription(
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: &Arc<str>,
customer: billing_customer::Model,
subscription: billing_subscription::Model,
) -> Result<(), anyhow::Error> {
let monthly_spending = llm_db
.get_user_spending_for_month(customer.user_id, Utc::now())
.await?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
let monthly_spending_over_free_tier =
monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
let mut update_params = stripe::UpdateSubscription {
proration_behavior: Some(
stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
),
..Default::default()
};
if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
item.price.as_ref().map_or(false, |price| {
price.id == stripe_llm_usage_price_id.as_ref()
})
}) {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
id: Some(existing_item.id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
} else {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
price: Some(stripe_llm_usage_price_id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
}
Subscription::update(stripe_client, &subscription_id, update_params).await?;
Ok(())
}
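The quantity reported to Stripe above is whole dollars of spending beyond the free tier, rounded up. A worked example, with plain integers standing in for `Cents`:

```
// $10.50 of usage beyond the free tier -> 1050 cents -> 10.5 -> 11 units.
let over_free_tier_cents: u32 = 1050;
let quantity = (over_free_tier_cents as f32 / 100.).ceil() as u64;
assert_eq!(quantity, 11);
```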

View File

@@ -10,8 +10,6 @@
Copy,
derive_more::Add,
derive_more::AddAssign,
derive_more::Sub,
derive_more::SubAssign,
)]
pub struct Cents(pub u32);

View File

@@ -114,31 +114,23 @@ impl Database {
pub async fn get_active_billing_subscriptions(
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(billing_customer::Column::UserId.is_in(user_ids))
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| async move {
let mut result = Vec::new();
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
result.push((customer, subscription));
}
Ok(subscriptions)
}
Ok(result)
})
.await
}

View File

@@ -317,7 +317,7 @@ impl Database {
inode: ActiveValue::set(entry.inode as i64),
mtime_seconds: ActiveValue::set(mtime.seconds as i64),
mtime_nanos: ActiveValue::set(mtime.nanos as i32),
canonical_path: ActiveValue::set(entry.canonical_path.clone()),
is_symlink: ActiveValue::set(entry.is_symlink),
is_ignored: ActiveValue::set(entry.is_ignored),
is_external: ActiveValue::set(entry.is_external),
git_status: ActiveValue::set(entry.git_status.map(|status| status as i64)),
@@ -338,7 +338,7 @@ impl Database {
worktree_entry::Column::Inode,
worktree_entry::Column::MtimeSeconds,
worktree_entry::Column::MtimeNanos,
worktree_entry::Column::CanonicalPath,
worktree_entry::Column::IsSymlink,
worktree_entry::Column::IsIgnored,
worktree_entry::Column::GitStatus,
worktree_entry::Column::ScanId,
@@ -735,7 +735,7 @@ impl Database {
seconds: db_entry.mtime_seconds as u64,
nanos: db_entry.mtime_nanos as u32,
}),
canonical_path: db_entry.canonical_path,
is_symlink: db_entry.is_symlink,
is_ignored: db_entry.is_ignored,
is_external: db_entry.is_external,
git_status: db_entry.git_status.map(|status| status as i32),
@@ -838,7 +838,6 @@ impl Database {
.map(|language_server| proto::LanguageServer {
id: language_server.id as u64,
name: language_server.name,
worktree_id: None,
})
.collect(),
dev_server_project_id: project.dev_server_project_id,

View File

@@ -659,7 +659,7 @@ impl Database {
seconds: db_entry.mtime_seconds as u64,
nanos: db_entry.mtime_nanos as u32,
}),
canonical_path: db_entry.canonical_path,
is_symlink: db_entry.is_symlink,
is_ignored: db_entry.is_ignored,
is_external: db_entry.is_external,
git_status: db_entry.git_status.map(|status| status as i32),
@@ -718,7 +718,6 @@ impl Database {
.map(|language_server| proto::LanguageServer {
id: language_server.id as u64,
name: language_server.name,
worktree_id: None,
})
.collect::<Vec<_>>();

View File

@@ -16,12 +16,12 @@ pub struct Model {
pub mtime_seconds: i64,
pub mtime_nanos: i32,
pub git_status: Option<i64>,
pub is_symlink: bool,
pub is_ignored: bool,
pub is_external: bool,
pub is_deleted: bool,
pub scan_id: i64,
pub is_fifo: bool,
pub canonical_path: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@@ -10,7 +10,6 @@ pub mod migrations;
mod rate_limiter;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
pub mod user_backfiller;
#[cfg(test)]
@@ -25,14 +24,11 @@ use axum::{
pub use cents::*;
use db::{ChannelId, Database};
use executor::Executor;
use llm::db::LlmDatabase;
pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
use crate::stripe_billing::StripeBilling;
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error {
@@ -180,6 +176,8 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_llm_access_price_id: Option<Arc<str>>,
pub stripe_llm_usage_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -198,6 +196,10 @@ impl Config {
}
}
pub fn is_llm_billing_enabled(&self) -> bool {
self.stripe_llm_usage_price_id.is_some()
}
#[cfg(test)]
pub fn test() -> Self {
Self {
@@ -236,6 +238,8 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_llm_access_price_id: None,
stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
}
@@ -268,11 +272,9 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
pub llm_db: Option<Arc<LlmDatabase>>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub clickhouse_client: Option<::clickhouse::Client>,
@@ -286,20 +288,6 @@ impl AppState {
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
.llm_database_url
.clone()
.zip(config.llm_database_max_connections)
{
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
llm_db_options.max_connections(llm_database_max_connections);
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
llm_db.initialize().await?;
Some(Arc::new(llm_db))
} else {
None
};
let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server
.as_ref()
@@ -316,16 +304,11 @@ impl AppState {
};
let db = Arc::new(db);
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
let this = Self {
db: db.clone(),
llm_db,
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_billing: stripe_client
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_client,
stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
clickhouse_client: config
@@ -338,11 +321,12 @@ impl AppState {
}
}
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
let api_key = config
.stripe_api_key
.as_ref()
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
Ok(stripe::Client::new(api_key))
}

View File

@@ -20,14 +20,13 @@ use axum::{
};
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
use isahc_http_client::IsahcHttpClient;
use rpc::ListModelsResponse;
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use std::{
pin::Pin,
sync::Arc,
@@ -44,7 +43,7 @@ pub struct LlmState {
pub config: Config,
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: ReqwestClient,
pub http_client: IsahcHttpClient,
pub clickhouse_client: Option<clickhouse::Client>,
active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
@@ -70,8 +69,11 @@ impl LlmState {
let db = Arc::new(db);
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client =
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
let http_client = IsahcHttpClient::builder()
.default_header("User-Agent", user_agent)
.build()
.map(IsahcHttpClient::from)
.context("failed to construct http client")?;
let this = Self {
executor,
@@ -416,7 +418,10 @@ async fn perform_completion(
claims,
provider: params.provider,
model,
tokens: TokenUsage::default(),
input_tokens: 0,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
inner_stream: stream,
})))
}
@@ -435,7 +440,7 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5);
/// The default value to use for maximum spend per month if the user did not
/// explicitly set a maximum spend.
@@ -443,6 +448,9 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
/// Used to prevent surprise bills.
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
/// The maximum lifetime spending an individual user can reach before being cut off.
const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);
async fn check_usage_limit(
state: &Arc<LlmState>,
provider: LanguageModelProvider,
@@ -460,28 +468,23 @@ async fn check_usage_limit(
)
.await?;
if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT {
if !claims.has_llm_subscription {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"Maximum spending limit reached for this month.".to_string(),
));
if state.config.is_llm_billing_enabled() {
if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT {
if !claims.has_llm_subscription {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"Maximum spending limit reached for this month.".to_string(),
));
}
}
}
if (usage.spending_this_month - FREE_TIER_MONTHLY_SPENDING_LIMIT)
>= Cents(claims.max_monthly_spend_in_cents)
{
return Err(Error::Http(
StatusCode::FORBIDDEN,
"Maximum spending limit reached for this month.".to_string(),
[(
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
));
}
// TODO: Remove this once we've rolled out monthly spending limits.
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT {
return Err(Error::http(
StatusCode::FORBIDDEN,
"Maximum spending limit reached.".to_string(),
));
}
let active_users = state.get_active_user_count(provider, model_name).await?;
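Condensed, the gate above makes two independent checks: the free tier only blocks users without a subscription (and, in the new version, only when usage billing is enabled), and the user's own monthly cap applies to spending beyond the free tier. A sketch with plain cent values (the function and its names are illustrative):

```
fn request_allowed(
    spending_this_month: u32, // cents
    free_tier_limit: u32,     // cents
    has_subscription: bool,
    max_monthly_spend: u32,   // cents, the user's own cap
    billing_enabled: bool,
) -> Result<(), &'static str> {
    if billing_enabled && spending_this_month >= free_tier_limit && !has_subscription {
        return Err("402: free tier exhausted, subscription required");
    }
    if spending_this_month.saturating_sub(free_tier_limit) >= max_monthly_spend {
        return Err("403: monthly spending cap reached");
    }
    Ok(())
}
```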
@@ -595,7 +598,10 @@ struct TokenCountingStream<S> {
claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
tokens: TokenUsage,
input_tokens: usize,
output_tokens: usize,
cache_creation_input_tokens: usize,
cache_read_input_tokens: usize,
inner_stream: S,
}
@@ -609,10 +615,10 @@ where
match Pin::new(&mut self.inner_stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut chunk))) => {
chunk.bytes.push(b'\n');
self.tokens.input += chunk.input_tokens;
self.tokens.output += chunk.output_tokens;
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
self.input_tokens += chunk.input_tokens;
self.output_tokens += chunk.output_tokens;
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
self.cache_read_input_tokens += chunk.cache_read_input_tokens;
Poll::Ready(Some(Ok(chunk.bytes)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
@@ -628,7 +634,10 @@ impl<S> Drop for TokenCountingStream<S> {
let claims = self.claims.clone();
let provider = self.provider;
let model = std::mem::take(&mut self.model);
let tokens = self.tokens;
let input_token_count = self.input_tokens;
let output_token_count = self.output_tokens;
let cache_creation_input_token_count = self.cache_creation_input_tokens;
let cache_read_input_token_count = self.cache_read_input_tokens;
self.state.executor.spawn_detached(async move {
let usage = state
.db
@@ -637,9 +646,10 @@ impl<S> Drop for TokenCountingStream<S> {
claims.is_staff,
provider,
&model,
tokens,
claims.has_llm_subscription,
Cents(claims.max_monthly_spend_in_cents),
input_token_count,
cache_creation_input_token_count,
cache_read_input_token_count,
output_token_count,
Utc::now(),
)
.await
@@ -669,23 +679,22 @@ impl<S> Drop for TokenCountingStream<S> {
},
model,
provider: provider.to_string(),
input_token_count: tokens.input as u64,
cache_creation_input_token_count: tokens.input_cache_creation as u64,
cache_read_input_token_count: tokens.input_cache_read as u64,
output_token_count: tokens.output as u64,
input_token_count: input_token_count as u64,
cache_creation_input_token_count: cache_creation_input_token_count
as u64,
cache_read_input_token_count: cache_read_input_token_count as u64,
output_token_count: output_token_count as u64,
requests_this_minute: usage.requests_this_minute as u64,
tokens_this_minute: usage.tokens_this_minute as u64,
tokens_this_day: usage.tokens_this_day as u64,
input_tokens_this_month: usage.tokens_this_month.input as u64,
input_tokens_this_month: usage.input_tokens_this_month as u64,
cache_creation_input_tokens_this_month: usage
.tokens_this_month
.input_cache_creation
.cache_creation_input_tokens_this_month
as u64,
cache_read_input_tokens_this_month: usage
.tokens_this_month
.input_cache_read
.cache_read_input_tokens_this_month
as u64,
output_tokens_this_month: usage.tokens_this_month.output as u64,
output_tokens_this_month: usage.output_tokens_this_month as u64,
spending_this_month: usage.spending_this_month.0 as u64,
lifetime_spending: usage.lifetime_spending.0 as u64,
},

View File

@@ -20,7 +20,7 @@ use std::future::Future;
use std::sync::Arc;
use anyhow::anyhow;
pub use queries::usages::{ActiveUserCount, TokenUsage};
pub use queries::usages::ActiveUserCount;
use sea_orm::prelude::*;
pub use sea_orm::ConnectOptions;
use sea_orm::{

View File

@@ -3,9 +3,8 @@ use serde::{Deserialize, Serialize};
use crate::id_type;
id_type!(BillingEventId);
id_type!(ModelId);
id_type!(ProviderId);
id_type!(RevokedAccessTokenId);
id_type!(UsageId);
id_type!(UsageMeasureId);
id_type!(RevokedAccessTokenId);

View File

@@ -1,6 +1,5 @@
use super::*;
pub mod billing_events;
pub mod providers;
pub mod revoked_access_tokens;
pub mod usages;

View File

@@ -1,31 +0,0 @@
use super::*;
use crate::Result;
use anyhow::Context as _;
impl LlmDatabase {
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
self.transaction(|tx| async move {
let events_with_models = billing_event::Entity::find()
.find_also_related(model::Entity)
.all(&*tx)
.await?;
events_with_models
.into_iter()
.map(|(event, model)| {
let model =
model.context("could not find model associated with billing event")?;
Ok((event, model))
})
.collect()
})
.await
}
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
self.transaction(|tx| async move {
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
Ok(())
})
.await
}
}

View File

@@ -1,5 +1,5 @@
use crate::db::UserId;
use crate::llm::Cents;
use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
use chrono::{Datelike, Duration};
use futures::StreamExt as _;
use rpc::LanguageModelProvider;
@@ -9,26 +9,15 @@ use strum::IntoEnumIterator as _;
use super::*;
#[derive(Debug, PartialEq, Clone, Copy, Default)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
pub input_cache_read: usize,
pub output: usize,
}
impl TokenUsage {
pub fn total(&self) -> usize {
self.input + self.input_cache_creation + self.input_cache_read + self.output
}
}
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
pub tokens_this_day: usize,
pub tokens_this_month: TokenUsage,
pub input_tokens_this_month: usize,
pub cache_creation_input_tokens_this_month: usize,
pub cache_read_input_tokens_this_month: usize,
pub output_tokens_this_month: usize,
pub spending_this_month: Cents,
pub lifetime_spending: Cents,
}
@@ -268,20 +257,18 @@ impl LlmDatabase {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage
.as_ref()
.map_or(0, |usage| usage.input_tokens as usize),
input_cache_creation: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
input_cache_read: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
output: monthly_usage
.as_ref()
.map_or(0, |usage| usage.output_tokens as usize),
},
input_tokens_this_month: monthly_usage
.as_ref()
.map_or(0, |usage| usage.input_tokens as usize),
cache_creation_input_tokens_this_month: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
cache_read_input_tokens_this_month: monthly_usage
.as_ref()
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
output_tokens_this_month: monthly_usage
.as_ref()
.map_or(0, |usage| usage.output_tokens as usize),
spending_this_month,
lifetime_spending,
})
@@ -296,9 +283,10 @@ impl LlmDatabase {
is_staff: bool,
provider: LanguageModelProvider,
model_name: &str,
tokens: TokenUsage,
has_llm_subscription: bool,
max_monthly_spend: Cents,
input_token_count: usize,
cache_creation_input_tokens: usize,
cache_read_input_tokens: usize,
output_token_count: usize,
now: DateTimeUtc,
) -> Result<Usage> {
self.transaction(|tx| async move {
@@ -325,6 +313,10 @@ impl LlmDatabase {
&tx,
)
.await?;
let total_token_count = input_token_count
+ cache_read_input_tokens
+ cache_creation_input_tokens
+ output_token_count;
let tokens_this_minute = self
.update_usage_for_measure(
user_id,
@@ -333,7 +325,7 @@ impl LlmDatabase {
&usages,
UsageMeasure::TokensPerMinute,
now,
tokens.total(),
total_token_count,
&tx,
)
.await?;
@@ -345,7 +337,7 @@ impl LlmDatabase {
&usages,
UsageMeasure::TokensPerDay,
now,
tokens.total(),
total_token_count,
&tx,
)
.await?;
@@ -369,14 +361,18 @@ impl LlmDatabase {
Some(usage) => {
monthly_usage::Entity::update(monthly_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
input_tokens: ActiveValue::set(
usage.input_tokens + input_token_count as i64,
),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
),
output_tokens: ActiveValue::set(
usage.output_tokens + output_token_count as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
@@ -388,12 +384,12 @@ impl LlmDatabase {
model_id: ActiveValue::set(model.id),
month: ActiveValue::set(month),
year: ActiveValue::set(year),
input_tokens: ActiveValue::set(tokens.input as i64),
input_tokens: ActiveValue::set(input_token_count as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
cache_creation_input_tokens as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
output_tokens: ActiveValue::set(output_token_count as i64),
..Default::default()
}
.insert(&*tx)
@@ -409,27 +405,6 @@ impl LlmDatabase {
monthly_usage.output_tokens as usize,
);
if !is_staff
&& spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
&& has_llm_subscription
&& (spending_this_month - FREE_TIER_MONTHLY_SPENDING_LIMIT) <= max_monthly_spend
{
billing_event::ActiveModel {
id: ActiveValue::not_set(),
idempotency_key: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
input_cache_creation_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
}
.insert(&*tx)
.await?;
}
// Update lifetime usage
let lifetime_usage = lifetime_usage::Entity::find()
.filter(
@@ -444,14 +419,18 @@ impl LlmDatabase {
Some(usage) => {
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
input_tokens: ActiveValue::set(
usage.input_tokens + input_token_count as i64,
),
cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
),
cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
),
output_tokens: ActiveValue::set(
usage.output_tokens + output_token_count as i64,
),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default()
})
.exec(&*tx)
@@ -461,12 +440,12 @@ impl LlmDatabase {
lifetime_usage::ActiveModel {
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
input_tokens: ActiveValue::set(input_token_count as i64),
cache_creation_input_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
cache_creation_input_tokens as i64,
),
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
output_tokens: ActiveValue::set(output_token_count as i64),
..Default::default()
}
.insert(&*tx)
@@ -486,12 +465,11 @@ impl LlmDatabase {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
tokens_this_month: TokenUsage {
input: monthly_usage.input_tokens as usize,
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
output: monthly_usage.output_tokens as usize,
},
input_tokens_this_month: monthly_usage.input_tokens as usize,
cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
as usize,
cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
output_tokens_this_month: monthly_usage.output_tokens as usize,
spending_this_month,
lifetime_spending,
})

View File

@@ -1,4 +1,3 @@
pub mod billing_event;
pub mod lifetime_usage;
pub mod model;
pub mod monthly_usage;

View File

@@ -1,37 +0,0 @@
use crate::{
db::UserId,
llm::db::{BillingEventId, ModelId},
};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_events")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingEventId,
pub idempotency_key: Uuid,
pub user_id: UserId,
pub model_id: ModelId,
pub input_tokens: i64,
pub input_cache_creation_tokens: i64,
pub input_cache_read_tokens: i64,
pub output_tokens: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -29,8 +29,6 @@ pub enum Relation {
Provider,
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
#[sea_orm(has_many = "super::billing_event::Entity")]
BillingEvents,
}
impl Related<super::provider::Entity> for Entity {
@@ -45,10 +43,4 @@ impl Related<super::usage::Entity> for Entity {
}
}
impl Related<super::billing_event::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingEvents.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,4 +1,3 @@
mod billing_tests;
mod provider_tests;
mod usage_tests;

View File

@@ -1,148 +0,0 @@
use crate::{
db::UserId,
llm::{
db::{queries::providers::ModelParams, LlmDatabase, TokenUsage},
FREE_TIER_MONTHLY_SPENDING_LIMIT,
},
test_llm_db, Cents,
};
use chrono::{DateTime, Utc};
use pretty_assertions::assert_eq;
use rpc::LanguageModelProvider;
test_llm_db!(
test_billing_limit_exceeded,
test_billing_limit_exceeded_postgres
);
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
let provider = LanguageModelProvider::Anthropic;
let model = "fake-claude-limerick";
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
// Initialize the database and insert the model
db.initialize().await.unwrap();
db.insert_models(&[ModelParams {
provider,
name: model.to_string(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 50_000,
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
}])
.await
.unwrap();
// Set a fixed datetime for consistent testing
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
.unwrap()
.with_timezone(&Utc);
let user_id = UserId::from_proto(123);
let max_monthly_spend = Cents::from_dollars(11);
// Record usage that brings us close to the limit but doesn't exceed it
// Let's say we use $10.50 worth of tokens
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
let usage = TokenUsage {
input: tokens_to_use,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// Verify that before we record any usage, there are 0 billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 0);
db.record_usage(
user_id,
false,
provider,
model,
usage,
true,
max_monthly_spend,
now,
)
.await
.unwrap();
// Verify the recorded usage and spending
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
// Verify that we exceeded the free tier usage
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
// Verify that there is one `billing_event` record
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 1);
let (billing_event, _model) = &billing_events[0];
assert_eq!(billing_event.user_id, user_id);
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
assert_eq!(billing_event.input_cache_creation_tokens, 0);
assert_eq!(billing_event.input_cache_read_tokens, 0);
assert_eq!(billing_event.output_tokens, 0);
// Record usage that puts us at $20.50
let usage_2 = TokenUsage {
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
db.record_usage(
user_id,
false,
provider,
model,
usage_2,
true,
max_monthly_spend,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
// Verify that there are now two billing events
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
let usage_exceeding = TokenUsage {
input: tokens_to_exceed,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
};
// This request pushes spending past the user's maximum monthly spend, so it
// should not create another billing event
db.record_usage(
user_id,
false,
provider,
model,
usage_exceeding,
true,
max_monthly_spend,
now,
)
.await
.unwrap();
// Verify the updated usage and spending
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
// Verify that we never charge the user beyond their maximum monthly spend:
// the number of billing events is unchanged.
let billing_events = db.get_billing_events().await.unwrap();
assert_eq!(billing_events.len(), 2);
}
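
For reference, the dollar figures asserted in this test follow directly from the configured price of 5 cents per million input tokens. A minimal standalone sketch of that arithmetic (the helper name is ours, not part of the diff):

```
// Cost in cents for `tokens` at `price_per_million` cents per million tokens.
fn cost_in_cents(tokens: u64, price_per_million: u64) -> u64 {
    tokens * price_per_million / 1_000_000
}

fn main() {
    assert_eq!(cost_in_cents(210_000_000, 5), 1050); // $10.50
    assert_eq!(cost_in_cents(200_000_000, 5), 1000); // +$10.00 -> $20.50 total
    assert_eq!(cost_in_cents(20_000_000, 5), 100); // +$1.00 -> $21.50 total
}
```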

View File

@@ -2,7 +2,7 @@ use crate::{
db::UserId,
llm::db::{
queries::{providers::ModelParams, usages::Usage},
LlmDatabase, TokenUsage,
LlmDatabase,
},
test_llm_db, Cents,
};
@@ -36,42 +36,14 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
let user_id = UserId::from_proto(123);
let now = t0;
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
.await
.unwrap();
let now = t0 + Duration::seconds(10);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 2000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -80,12 +52,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 3000,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
input_tokens_this_month: 3000,
cache_creation_input_tokens_this_month: 0,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -99,35 +69,19 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 2000,
tokens_this_day: 3000,
tokens_this_month: TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
input_tokens_this_month: 3000,
cache_creation_input_tokens_this_month: 0,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
let now = t0 + Duration::seconds(60);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -136,12 +90,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 5000,
tokens_this_day: 6000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
input_tokens_this_month: 6000,
cache_creation_input_tokens_this_month: 0,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -156,34 +108,18 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 0,
tokens_this_minute: 0,
tokens_this_day: 5000,
tokens_this_month: TokenUsage {
input: 6000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
input_tokens_this_month: 6000,
cache_creation_input_tokens_this_month: 0,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 4000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -192,12 +128,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 4000,
tokens_this_day: 9000,
tokens_this_month: TokenUsage {
input: 10000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
input_tokens_this_month: 10000,
cache_creation_input_tokens_this_month: 0,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
@@ -209,23 +143,9 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
.with_timezone(&Utc);
// Test cache creation input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -234,35 +154,19 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1,
tokens_this_minute: 1500,
tokens_this_day: 1500,
tokens_this_month: TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
input_tokens_this_month: 1000,
cache_creation_input_tokens_this_month: 500,
cache_read_input_tokens_this_month: 0,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
);
// Test cache read input tokens
db.record_usage(
user_id,
false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 300,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!(
@@ -271,12 +175,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2,
tokens_this_minute: 2800,
tokens_this_day: 2800,
tokens_this_month: TokenUsage {
input: 2000,
input_cache_creation: 500,
input_cache_read: 300,
output: 0,
},
input_tokens_this_month: 2000,
cache_creation_input_tokens_this_month: 500,
cache_read_input_tokens_this_month: 300,
output_tokens_this_month: 0,
spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO,
}
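
Read together, the paired call sites above show the same mechanical refactor throughout this file: the grouped `TokenUsage` argument corresponds one-to-one with four positional token counts, and the aggregated `tokens_this_month` struct is flattened into four `*_this_month` fields. A small sketch of the equivalence (the free functions and types here are illustrative, not the real `LlmDatabase` API):

```
#[derive(Default)]
struct TokenUsage {
    input: u64,
    input_cache_creation: u64,
    input_cache_read: u64,
    output: u64,
}

// Grouped form: one struct carries all four counts.
fn total_grouped(usage: &TokenUsage) -> u64 {
    usage.input + usage.input_cache_creation + usage.input_cache_read + usage.output
}

// Flattened form: the same four counts as positional arguments.
fn total_flat(input: u64, cache_creation: u64, cache_read: u64, output: u64) -> u64 {
    input + cache_creation + cache_read + output
}

fn main() {
    let usage = TokenUsage { input: 1000, input_cache_creation: 500, ..Default::default() };
    assert_eq!(total_grouped(&usage), total_flat(1000, 500, 0, 0));
}
```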

View File

@@ -111,13 +111,6 @@ async fn main() -> Result<()> {
let state = AppState::new(config, Executor::Production).await?;
if let Some(stripe_billing) = state.stripe_billing.clone() {
let executor = state.executor.clone();
executor.spawn_detached(async move {
stripe_billing.initialize().await.trace_err();
});
}
if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically(
@@ -132,8 +125,6 @@ async fn main() -> Result<()> {
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
app = app
.merge(collab::api::routes(rpc_server.clone()))
.merge(collab::rpc::routes(rpc_server.clone()));
@@ -142,6 +133,7 @@ async fn main() -> Result<()> {
}
if mode.is_api() {
poll_stripe_events_periodically(state.clone());
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());
@@ -165,7 +157,7 @@ async fn main() -> Result<()> {
if let Some(mut llm_db) = llm_db {
llm_db.initialize().await?;
sync_llm_usage_with_stripe_periodically(state.clone());
sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
}
app = app
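
The diff above only moves the call sites of the periodic jobs (`poll_stripe_events_periodically`, `fetch_extensions_from_blob_store_periodically`, `sync_llm_usage_with_stripe_periodically`) between server modes; their bodies are not shown. As a hedged sketch, such a job typically takes the shape below (the interval, error handling, and tokio runtime are all assumptions, not the actual implementation):

```
use std::future::Future;
use std::time::Duration;

// Hypothetical periodic job: run `sync` forever on a fixed interval,
// logging failures rather than crashing the server.
async fn run_periodically<F, Fut>(interval: Duration, mut sync: F)
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<(), String>>,
{
    loop {
        if let Err(error) = sync().await {
            eprintln!("periodic sync failed: {error}");
        }
        tokio::time::sleep(interval).await;
    }
}

#[tokio::main]
async fn main() {
    // Example: a no-op sync every 60 seconds (runs forever).
    run_periodically(Duration::from_secs(60), || async { Ok::<(), String>(()) }).await;
}
```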

View File

@@ -36,8 +36,8 @@ use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use http_client::HttpClient;
use isahc_http_client::IsahcHttpClient;
use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
use reqwest_client::ReqwestClient;
use sha2::Digest;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
@@ -961,8 +961,8 @@ impl Server {
tracing::info!("connection opened");
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = match ReqwestClient::user_agent(&user_agent) {
Ok(http_client) => Arc::new(http_client),
let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
Err(error) => {
tracing::error!(?error, "failed to create HTTP client");
return;
@@ -1218,15 +1218,6 @@ impl Server {
Ok(())
}
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, proto::RefreshLlmToken {})
.trace_err();
}
}
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot {
connection_pool: ConnectionPoolGuard {
@@ -2237,7 +2228,7 @@ fn join_project_internal(
worktree_id: worktree.id,
path: settings_file.path,
content: Some(settings_file.content),
kind: Some(settings_file.kind.to_proto() as i32),
kind: Some(proto::update_user_settings::Kind::Settings.into()),
},
)?;
}

View File

@@ -1,479 +0,0 @@
use std::sync::Arc;
use crate::{llm, Cents, Result};
use anyhow::Context;
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
}
#[derive(Default)]
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
}
pub struct StripeModel {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
input_cache_read_tokens_price: StripeBillingPrice,
output_tokens_price: StripeBillingPrice,
}
struct StripeBillingPrice {
id: stripe::PriceId,
meter_event_name: String,
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client,
state: RwLock::default(),
}
}
pub async fn initialize(&self) -> Result<()> {
log::info!("StripeBilling: initializing");
let mut state = self.state.write().await;
let (meters, prices) = futures::try_join!(
StripeMeter::list(&self.client),
stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
}
)
)?;
for meter in meters.data {
state
.meters_by_event_name
.insert(meter.event_name.clone(), meter);
}
for price in prices.data {
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
state.price_ids_by_meter_id.insert(meter, price.id);
}
}
}
log::info!("StripeBilling: initialized");
Ok(())
}
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
let input_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_tokens", model.id),
&format!("{} (Input Tokens)", model.name),
Cents::new(model.price_per_million_input_tokens as u32),
)
.await?;
let input_cache_creation_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_cache_creation_tokens", model.id),
&format!("{} (Input Cache Creation Tokens)", model.name),
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
)
.await?;
let input_cache_read_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_cache_read_tokens", model.id),
&format!("{} (Input Cache Read Tokens)", model.name),
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
)
.await?;
let output_tokens_price = self
.get_or_insert_price(
&format!("model_{}/output_tokens", model.id),
&format!("{} (Output Tokens)", model.name),
Cents::new(model.price_per_million_output_tokens as u32),
)
.await?;
Ok(StripeModel {
input_tokens_price,
input_cache_creation_tokens_price,
input_cache_read_tokens_price,
output_tokens_price,
})
}
async fn get_or_insert_price(
&self,
meter_event_name: &str,
price_description: &str,
price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> {
// Fast code path when the meter and the price already exist.
{
let state = self.state.read().await;
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
return Ok(StripeBillingPrice {
id: price_id.clone(),
meter_event_name: meter_event_name.to_string(),
});
}
}
}
let mut state = self.state.write().await;
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
meter.clone()
} else {
let meter = StripeMeter::create(
&self.client,
StripeCreateMeterParams {
default_aggregation: DefaultAggregation { formula: "sum" },
display_name: price_description.to_string(),
event_name: meter_event_name,
},
)
.await?;
state
.meters_by_event_name
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
price_id.clone()
} else {
let price = stripe::Price::create(
&self.client,
stripe::CreatePrice {
active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
active: Some(true),
metadata: None,
name: price_description.to_string(),
statement_descriptor: None,
tax_code: None,
unit_label: None,
}),
recurring: Some(stripe::CreatePriceRecurring {
aggregate_usage: None,
interval: stripe::CreatePriceRecurringInterval::Month,
interval_count: None,
trial_period_days: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
meter: Some(meter.id.clone()),
}),
tax_behavior: None,
tiers: None,
tiers_mode: None,
transfer_lookup_key: None,
transform_quantity: None,
unit_amount: None,
unit_amount_decimal: Some(&format!(
"{:.12}",
price_per_million_tokens.0 as f64 / 1_000_000f64
)),
},
)
.await?;
state
.price_ids_by_meter_id
.insert(meter.id, price.id.clone());
price.id
};
Ok(StripeBillingPrice {
id: price_id,
meter_event_name: meter_event_name.to_string(),
})
}
pub async fn subscribe_to_model(
&self,
subscription_id: &stripe::SubscriptionId,
model: &StripeModel,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
let mut items = Vec::new();
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
{
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_cache_read_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.output_tokens_price.id.to_string()),
..Default::default()
});
}
if !items.is_empty() {
items.extend(subscription.items.data.iter().map(|item| {
stripe::UpdateSubscriptionItems {
id: Some(item.id.to_string()),
..Default::default()
}
}));
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(items),
..Default::default()
},
)
.await?;
}
Ok(())
}
pub async fn bill_model_usage(
&self,
customer_id: &stripe::CustomerId,
model: &StripeModel,
event: &llm::db::billing_event::Model,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
if event.input_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_tokens/{}", event.idempotency_key),
event_name: &model.input_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.input_cache_creation_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_cache_creation_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.input_cache_read_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
event_name: &model.input_cache_read_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_cache_read_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.output_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("output_tokens/{}", event.idempotency_key),
event_name: &model.output_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.output_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
Ok(())
}
pub async fn checkout(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
model: &StripeModel,
success_url: &str,
) -> Result<String> {
let first_of_next_month = Utc::now()
.checked_add_months(chrono::Months::new(1))
.unwrap()
.with_day(1)
.unwrap();
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
billing_cycle_anchor: Some(first_of_next_month.timestamp()),
..Default::default()
});
params.line_items = Some(
[
&model.input_tokens_price.id,
&model.input_cache_creation_tokens_price.id,
&model.input_cache_read_tokens_price.id,
&model.output_tokens_price.id,
]
.into_iter()
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
price: Some(price_id.to_string()),
..Default::default()
})
.collect(),
);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
#[derive(Serialize)]
struct DefaultAggregation {
formula: &'static str,
}
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
default_aggregation: DefaultAggregation,
display_name: String,
event_name: &'a str,
}
#[derive(Clone, Deserialize)]
struct StripeMeter {
id: String,
event_name: String,
}
impl StripeMeter {
pub fn create(
client: &stripe::Client,
params: StripeCreateMeterParams,
) -> stripe::Response<Self> {
client.post_form("/billing/meters", params)
}
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
client.get_query("/billing/meters", Params { limit: Some(100) })
}
}
#[derive(Deserialize)]
struct StripeMeterEvent {
identifier: String,
}
impl StripeMeterEvent {
pub async fn create(
client: &stripe::Client,
params: StripeCreateMeterEventParams<'_>,
) -> Result<Self, stripe::StripeError> {
let identifier = params.identifier;
match client.post_form("/billing/meter_events", params).await {
Ok(event) => Ok(event),
Err(stripe::StripeError::Stripe(error)) => {
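// Presumably Stripe's duplicate-identifier rejection: a 400 whose message
// echoes the identifier means this meter event was already recorded, so
// treat it as a success to keep retries idempotent.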
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(Self {
identifier: identifier.to_string(),
})
} else {
Err(stripe::StripeError::Stripe(error))
}
}
Err(error) => Err(error),
}
}
}
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
identifier: &'a str,
event_name: &'a str,
payload: StripeCreateMeterEventPayload<'a>,
timestamp: Option<i64>,
}
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
value: u64,
stripe_customer_id: &'a stripe::CustomerId,
}
fn subscription_contains_price(
subscription: &stripe::Subscription,
price_id: &stripe::PriceId,
) -> bool {
subscription.items.data.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)
})
}
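
One detail worth noting in the deleted `get_or_insert_price` above: it takes a read lock for the fast path, then upgrades to a write lock and re-checks before creating anything, since another task may have inserted the meter or price in between. A minimal standalone sketch of that read-then-recheck pattern (names are illustrative, assuming `tokio::sync::RwLock`):

```
use std::collections::HashMap;
use tokio::sync::RwLock;

async fn get_or_insert(cache: &RwLock<HashMap<String, u32>>, key: &str, make: u32) -> u32 {
    // Fast path: a shared read lock suffices when the value already exists.
    if let Some(value) = cache.read().await.get(key) {
        return *value;
    }
    // Slow path: take the write lock, then re-check, because another task
    // may have inserted the value between our read and write locks.
    let mut cache = cache.write().await;
    *cache.entry(key.to_string()).or_insert(make)
}

#[tokio::main]
async fn main() {
    let cache = RwLock::new(HashMap::new());
    assert_eq!(get_or_insert(&cache, "meter", 7).await, 7);
    assert_eq!(get_or_insert(&cache, "meter", 9).await, 7); // cached value wins
}
```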

View File

@@ -12,7 +12,6 @@ use editor::{
test::editor_test_context::{AssertionContextManager, EditorTestContext},
Editor,
};
use fs::Fs;
use futures::StreamExt;
use gpui::{TestAppContext, UpdateGlobal, VisualContext, VisualTestContext};
use indoc::indoc;
@@ -31,7 +30,7 @@ use serde_json::json;
use settings::SettingsStore;
use std::{
ops::Range,
path::{Path, PathBuf},
path::Path,
sync::{
atomic::{self, AtomicBool, AtomicUsize},
Arc,
@@ -61,7 +60,7 @@ async fn test_host_disconnect(
.fs()
.insert_tree(
"/a",
json!({
serde_json::json!({
"a.txt": "a-contents",
"b.txt": "b-contents",
}),
@@ -2153,295 +2152,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
});
}
#[gpui::test(iterations = 30)]
async fn test_collaborating_with_editorconfig(
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(cx_a.executor()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
.create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
.await;
let active_call_a = cx_a.read(ActiveCall::global);
cx_b.update(editor::init);
// Set up a fake language server.
client_a.language_registry().add(rust_lang());
client_a
.fs()
.insert_tree(
"/a",
json!({
"src": {
"main.rs": "mod other;\nfn main() { let foo = other::foo(); }",
"other_mod": {
"other.rs": "pub fn foo() -> usize {\n 4\n}",
".editorconfig": "",
},
},
".editorconfig": "[*]\ntab_width = 2\n",
}),
)
.await;
let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await;
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
.unwrap();
let main_buffer_a = project_a
.update(cx_a, |p, cx| {
p.open_buffer((worktree_id, "src/main.rs"), cx)
})
.await
.unwrap();
let other_buffer_a = project_a
.update(cx_a, |p, cx| {
p.open_buffer((worktree_id, "src/other_mod/other.rs"), cx)
})
.await
.unwrap();
let cx_a = cx_a.add_empty_window();
let main_editor_a =
cx_a.new_view(|cx| Editor::for_buffer(main_buffer_a, Some(project_a.clone()), cx));
let other_editor_a =
cx_a.new_view(|cx| Editor::for_buffer(other_buffer_a, Some(project_a), cx));
let mut main_editor_cx_a = EditorTestContext {
cx: cx_a.clone(),
window: cx_a.handle(),
editor: main_editor_a,
assertion_cx: AssertionContextManager::new(),
};
let mut other_editor_cx_a = EditorTestContext {
cx: cx_a.clone(),
window: cx_a.handle(),
editor: other_editor_a,
assertion_cx: AssertionContextManager::new(),
};
// Join the project as client B.
let project_b = client_b.join_remote_project(project_id, cx_b).await;
let main_buffer_b = project_b
.update(cx_b, |p, cx| {
p.open_buffer((worktree_id, "src/main.rs"), cx)
})
.await
.unwrap();
let other_buffer_b = project_b
.update(cx_b, |p, cx| {
p.open_buffer((worktree_id, "src/other_mod/other.rs"), cx)
})
.await
.unwrap();
let cx_b = cx_b.add_empty_window();
let main_editor_b =
cx_b.new_view(|cx| Editor::for_buffer(main_buffer_b, Some(project_b.clone()), cx));
let other_editor_b =
cx_b.new_view(|cx| Editor::for_buffer(other_buffer_b, Some(project_b.clone()), cx));
let mut main_editor_cx_b = EditorTestContext {
cx: cx_b.clone(),
window: cx_b.handle(),
editor: main_editor_b,
assertion_cx: AssertionContextManager::new(),
};
let mut other_editor_cx_b = EditorTestContext {
cx: cx_b.clone(),
window: cx_b.handle(),
editor: other_editor_b,
assertion_cx: AssertionContextManager::new(),
};
let initial_main = indoc! {"
ˇmod other;
fn main() { let foo = other::foo(); }"};
let initial_other = indoc! {"
ˇpub fn foo() -> usize {
4
}"};
let first_tabbed_main = indoc! {"
ˇmod other;
fn main() { let foo = other::foo(); }"};
tab_undo_assert(
&mut main_editor_cx_a,
&mut main_editor_cx_b,
initial_main,
first_tabbed_main,
true,
);
tab_undo_assert(
&mut main_editor_cx_a,
&mut main_editor_cx_b,
initial_main,
first_tabbed_main,
false,
);
let first_tabbed_other = indoc! {"
ˇpub fn foo() -> usize {
4
}"};
tab_undo_assert(
&mut other_editor_cx_a,
&mut other_editor_cx_b,
initial_other,
first_tabbed_other,
true,
);
tab_undo_assert(
&mut other_editor_cx_a,
&mut other_editor_cx_b,
initial_other,
first_tabbed_other,
false,
);
client_a
.fs()
.atomic_write(
PathBuf::from("/a/src/.editorconfig"),
"[*]\ntab_width = 3\n".to_owned(),
)
.await
.unwrap();
cx_a.run_until_parked();
cx_b.run_until_parked();
let second_tabbed_main = indoc! {"
ˇmod other;
fn main() { let foo = other::foo(); }"};
tab_undo_assert(
&mut main_editor_cx_a,
&mut main_editor_cx_b,
initial_main,
second_tabbed_main,
true,
);
tab_undo_assert(
&mut main_editor_cx_a,
&mut main_editor_cx_b,
initial_main,
second_tabbed_main,
false,
);
let second_tabbed_other = indoc! {"
ˇpub fn foo() -> usize {
4
}"};
tab_undo_assert(
&mut other_editor_cx_a,
&mut other_editor_cx_b,
initial_other,
second_tabbed_other,
true,
);
tab_undo_assert(
&mut other_editor_cx_a,
&mut other_editor_cx_b,
initial_other,
second_tabbed_other,
false,
);
let editorconfig_buffer_b = project_b
.update(cx_b, |p, cx| {
p.open_buffer((worktree_id, "src/other_mod/.editorconfig"), cx)
})
.await
.unwrap();
editorconfig_buffer_b.update(cx_b, |buffer, cx| {
buffer.set_text("[*.rs]\ntab_width = 6\n", cx);
});
project_b
.update(cx_b, |project, cx| {
project.save_buffer(editorconfig_buffer_b.clone(), cx)
})
.await
.unwrap();
cx_a.run_until_parked();
cx_b.run_until_parked();
tab_undo_assert(
&mut main_editor_cx_a,
&mut main_editor_cx_b,
initial_main,
second_tabbed_main,
true,
);
tab_undo_assert(
&mut main_editor_cx_a,
&mut main_editor_cx_b,
initial_main,
second_tabbed_main,
false,
);
let third_tabbed_other = indoc! {"
ˇpub fn foo() -> usize {
4
}"};
tab_undo_assert(
&mut other_editor_cx_a,
&mut other_editor_cx_b,
initial_other,
third_tabbed_other,
true,
);
tab_undo_assert(
&mut other_editor_cx_a,
&mut other_editor_cx_b,
initial_other,
third_tabbed_other,
false,
);
}
#[track_caller]
fn tab_undo_assert(
cx_a: &mut EditorTestContext,
cx_b: &mut EditorTestContext,
expected_initial: &str,
expected_tabbed: &str,
a_tabs: bool,
) {
cx_a.assert_editor_state(expected_initial);
cx_b.assert_editor_state(expected_initial);
if a_tabs {
cx_a.update_editor(|editor, cx| {
editor.tab(&editor::actions::Tab, cx);
});
} else {
cx_b.update_editor(|editor, cx| {
editor.tab(&editor::actions::Tab, cx);
});
}
cx_a.run_until_parked();
cx_b.run_until_parked();
cx_a.assert_editor_state(expected_tabbed);
cx_b.assert_editor_state(expected_tabbed);
if a_tabs {
cx_a.update_editor(|editor, cx| {
editor.undo(&editor::actions::Undo, cx);
});
} else {
cx_b.update_editor(|editor, cx| {
editor.undo(&editor::actions::Undo, cx);
});
}
cx_a.run_until_parked();
cx_b.run_until_parked();
cx_a.assert_editor_state(expected_initial);
cx_b.assert_editor_state(expected_initial);
}
fn extract_hint_labels(editor: &Editor) -> Vec<String> {
let mut labels = Vec::new();
for hint in editor.inlay_hint_cache().hints() {

View File

@@ -27,14 +27,13 @@ use language::{
use live_kit_client::MacOSDisplay;
use lsp::LanguageServerId;
use parking_lot::Mutex;
use project::lsp_store::FormatTarget;
use project::{
lsp_store::FormatTrigger, search::SearchQuery, search::SearchResult, DiagnosticSummary,
HoverBlockKind, Project, ProjectPath,
};
use rand::prelude::*;
use serde_json::json;
use settings::SettingsStore;
use settings::{LocalSettingsKind, SettingsStore};
use std::{
cell::{Cell, RefCell},
env, future, mem,
@@ -3328,8 +3327,16 @@ async fn test_local_settings(
.local_settings(worktree_b.read(cx).id())
.collect::<Vec<_>>(),
&[
(Path::new("").into(), r#"{"tab_size":2}"#.to_string()),
(Path::new("a").into(), r#"{"tab_size":8}"#.to_string()),
(
Path::new("").into(),
LocalSettingsKind::Settings,
r#"{"tab_size":2}"#.to_string()
),
(
Path::new("a").into(),
LocalSettingsKind::Settings,
r#"{"tab_size":8}"#.to_string()
),
]
)
});
@@ -3347,8 +3354,16 @@ async fn test_local_settings(
.local_settings(worktree_b.read(cx).id())
.collect::<Vec<_>>(),
&[
(Path::new("").into(), r#"{}"#.to_string()),
(Path::new("a").into(), r#"{"tab_size":8}"#.to_string()),
(
Path::new("").into(),
LocalSettingsKind::Settings,
r#"{}"#.to_string()
),
(
Path::new("a").into(),
LocalSettingsKind::Settings,
r#"{"tab_size":8}"#.to_string()
),
]
)
});
@@ -3376,8 +3391,16 @@ async fn test_local_settings(
.local_settings(worktree_b.read(cx).id())
.collect::<Vec<_>>(),
&[
(Path::new("a").into(), r#"{"tab_size":8}"#.to_string()),
(Path::new("b").into(), r#"{"tab_size":4}"#.to_string()),
(
Path::new("a").into(),
LocalSettingsKind::Settings,
r#"{"tab_size":8}"#.to_string()
),
(
Path::new("b").into(),
LocalSettingsKind::Settings,
r#"{"tab_size":4}"#.to_string()
),
]
)
});
@@ -3407,7 +3430,11 @@ async fn test_local_settings(
store
.local_settings(worktree_b.read(cx).id())
.collect::<Vec<_>>(),
&[(Path::new("a").into(), r#"{"hard_tabs":true}"#.to_string()),]
&[(
Path::new("a").into(),
LocalSettingsKind::Settings,
r#"{"hard_tabs":true}"#.to_string()
),]
)
});
}
@@ -4390,7 +4417,6 @@ async fn test_formatting_buffer(
HashSet::from_iter([buffer_b.clone()]),
true,
FormatTrigger::Save,
FormatTarget::Buffer,
cx,
)
})
@@ -4424,7 +4450,6 @@ async fn test_formatting_buffer(
HashSet::from_iter([buffer_b.clone()]),
true,
FormatTrigger::Save,
FormatTarget::Buffer,
cx,
)
})
@@ -4530,7 +4555,6 @@ async fn test_prettier_formatting_buffer(
HashSet::from_iter([buffer_b.clone()]),
true,
FormatTrigger::Save,
FormatTarget::Buffer,
cx,
)
})
@@ -4550,7 +4574,6 @@ async fn test_prettier_formatting_buffer(
HashSet::from_iter([buffer_a.clone()]),
true,
FormatTrigger::Manual,
FormatTarget::Buffer,
cx,
)
})

View File

@@ -2,12 +2,10 @@ use crate::tests::TestServer;
use call::ActiveCall;
use fs::{FakeFs, Fs as _};
use gpui::{Context as _, TestAppContext};
use http_client::BlockedHttpClient;
use language::{language_settings::language_settings, LanguageRegistry};
use node_runtime::NodeRuntime;
use language::language_settings::all_language_settings;
use project::ProjectPath;
use remote::SshRemoteClient;
use remote_server::{HeadlessAppState, HeadlessProject};
use remote_server::HeadlessProject;
use serde_json::json;
use std::{path::Path, sync::Arc};
@@ -26,7 +24,7 @@ async fn test_sharing_an_ssh_remote_project(
.await;
// Set up project on remote FS
let (port, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
let (client_ssh, server_ssh) = SshRemoteClient::fake(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor());
remote_fs
.insert_tree(
@@ -50,24 +48,9 @@ async fn test_sharing_an_ssh_remote_project(
// User A connects to the remote project via SSH.
server_cx.update(HeadlessProject::init);
let remote_http_client = Arc::new(BlockedHttpClient);
let node = NodeRuntime::unavailable();
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
let _headless_project = server_cx.new_model(|cx| {
client::init_settings(cx);
HeadlessProject::new(
HeadlessAppState {
session: server_ssh,
fs: remote_fs.clone(),
http_client: remote_http_client,
node_runtime: node,
languages,
},
cx,
)
});
let _headless_project =
server_cx.new_model(|cx| HeadlessProject::new(server_ssh, remote_fs.clone(), cx));
let client_ssh = SshRemoteClient::fake_client(port, cx_a).await;
let (project_a, worktree_id) = client_a
.build_ssh_project("/code/project1", client_ssh, cx_a)
.await;
@@ -135,7 +118,9 @@ async fn test_sharing_an_ssh_remote_project(
cx_b.read(|cx| {
let file = buffer_b.read(cx).file();
assert_eq!(
language_settings(Some("Rust".into()), file, cx).language_servers,
all_language_settings(file, cx)
.language(Some(&("Rust".into())))
.language_servers,
["override-rust-analyzer".to_string()]
)
});

View File

@@ -635,11 +635,9 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
llm_db: None,
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None,
stripe_client: None,
stripe_billing: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
clickhouse_client: None,
@@ -679,6 +677,8 @@ impl TestServer {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_llm_access_price_id: None,
stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
},

View File

@@ -26,7 +26,7 @@ const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
type NotificationHandler = Box<dyn Send + FnMut(RequestId, Value, AsyncAppContext)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
@@ -94,6 +94,7 @@ enum CspResult<T> {
#[derive(Serialize, Deserialize)]
struct Notification<'a, T> {
jsonrpc: &'static str,
id: RequestId,
#[serde(borrow)]
method: &'a str,
params: T,
@@ -102,6 +103,7 @@ struct Notification<'a, T> {
#[derive(Debug, Clone, Deserialize)]
struct AnyNotification<'a> {
jsonrpc: &'a str,
id: RequestId,
method: String,
#[serde(default)]
params: Option<Value>,
@@ -244,7 +246,11 @@ impl Client {
if let Some(handler) =
notification_handlers.get_mut(notification.method.as_str())
{
handler(notification.params.unwrap_or(Value::Null), cx.clone());
handler(
notification.id,
notification.params.unwrap_or(Value::Null),
cx.clone(),
);
}
}
}
@@ -372,8 +378,10 @@ impl Client {
/// Sends a notification to the context server without expecting a response.
/// This function serializes the notification and sends it through the outbound channel.
pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
let id = self.next_id.fetch_add(1, SeqCst);
let notification = serde_json::to_string(&Notification {
jsonrpc: JSON_RPC_VERSION,
id: RequestId::Int(id),
method,
params,
})
@@ -382,13 +390,13 @@ impl Client {
Ok(())
}
pub fn on_notification<F>(&self, method: &'static str, f: F)
pub fn on_notification<F>(&self, method: &'static str, mut f: F)
where
F: 'static + Send + FnMut(Value, AsyncAppContext),
{
self.notification_handlers
.lock()
.insert(method, Box::new(f));
.insert(method, Box::new(move |_, params, cx| f(params, cx)));
}
pub fn name(&self) -> &str {
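
The change above gives context-server notifications a `RequestId`, so every outbound notification now carries an `id` field (strict JSON-RPC 2.0 notifications omit `id`; this protocol diverges from that). A small sketch of the resulting wire shape (the structs are condensed from the diff; the serde setup and values are assumptions for illustration):

```
use serde::Serialize;

#[derive(Serialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum RequestId {
    Int(i32),
    Str(String),
}

#[derive(Serialize)]
struct Notification<'a, T> {
    jsonrpc: &'static str,
    id: RequestId,
    method: &'a str,
    params: T,
}

fn main() {
    let notification = Notification {
        jsonrpc: "2.0",
        id: RequestId::Int(1),
        method: "notifications/initialized",
        params: serde_json::json!({}),
    };
    // Prints: {"jsonrpc":"2.0","id":1,"method":"notifications/initialized","params":{}}
    println!("{}", serde_json::to_string(&notification).unwrap());
}
```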

View File

@@ -85,7 +85,7 @@ impl ContextServer {
)?;
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
let client_info = types::EntityInfo {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};

View File

@@ -11,6 +11,8 @@ use collections::HashMap;
use crate::client::Client;
use crate::types;
pub use types::PromptInfo;
const PROTOCOL_VERSION: u32 = 1;
pub struct ModelContextProtocol {
@@ -24,7 +26,7 @@ impl ModelContextProtocol {
pub async fn initialize(
self,
client_info: types::Implementation,
client_info: types::EntityInfo,
) -> Result<InitializedContextServerProtocol> {
let params = types::InitializeParams {
protocol_version: PROTOCOL_VERSION,
@@ -94,7 +96,7 @@ impl InitializedContextServerProtocol {
}
/// List the MCP prompts.
pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
self.check_capability(ServerCapability::Prompts)?;
let response: types::PromptsListResponse = self
@@ -105,18 +107,6 @@ impl InitializedContextServerProtocol {
Ok(response.prompts)
}
/// List the MCP resources.
pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
self.check_capability(ServerCapability::Resources)?;
let response: types::ResourcesListResponse = self
.inner
.request(types::RequestType::ResourcesList.as_str(), ())
.await?;
Ok(response)
}
/// Executes a prompt with the given arguments and returns the result.
pub async fn run_prompt<P: AsRef<str>>(
&self,

View File

@@ -15,7 +15,6 @@ pub enum RequestType {
PromptsGet,
PromptsList,
CompletionComplete,
Ping,
}
impl RequestType {
@@ -31,7 +30,6 @@ impl RequestType {
RequestType::PromptsGet => "prompts/get",
RequestType::PromptsList => "prompts/list",
RequestType::CompletionComplete => "completion/complete",
RequestType::Ping => "ping",
}
}
}
@@ -41,15 +39,14 @@ impl RequestType {
pub struct InitializeParams {
pub protocol_version: u32,
pub capabilities: ClientCapabilities,
pub client_info: Implementation,
pub client_info: EntityInfo,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolParams {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, serde_json::Value>>,
pub arguments: Option<serde_json::Value>,
}
#[derive(Debug, Serialize)]
@@ -80,7 +77,6 @@ pub struct LoggingSetLevelParams {
#[serde(rename_all = "camelCase")]
pub struct PromptsGetParams {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<HashMap<String, String>>,
}
@@ -105,13 +101,6 @@ pub struct PromptReference {
pub name: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceReference {
pub r#type: PromptReferenceType,
pub uri: Url,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PromptReferenceType {
@@ -121,6 +110,13 @@ pub enum PromptReferenceType {
Resource,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceReference {
pub r#type: String,
pub uri: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionArgument {
@@ -133,7 +129,7 @@ pub struct CompletionArgument {
pub struct InitializeResponse {
pub protocol_version: u32,
pub capabilities: ServerCapabilities,
pub server_info: Implementation,
pub server_info: EntityInfo,
}
#[derive(Debug, Deserialize)]
@@ -145,39 +141,13 @@ pub struct ResourcesReadResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub resource_templates: Option<Vec<ResourceTemplate>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<Vec<Resource>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingMessage {
pub role: SamplingRole,
pub content: SamplingContent,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SamplingRole {
User,
Assistant,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SamplingContent {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { data: String, mime_type: String },
pub resources: Vec<Resource>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub prompt: String,
}
@@ -185,7 +155,7 @@ pub struct PromptsGetResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsListResponse {
pub prompts: Vec<Prompt>,
pub prompts: Vec<PromptInfo>,
}
#[derive(Debug, Deserialize)]
@@ -198,91 +168,61 @@ pub struct CompletionCompleteResponse {
#[serde(rename_all = "camelCase")]
pub struct CompletionResult {
pub values: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub has_more: Option<bool>,
}
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Prompt {
pub struct PromptInfo {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<Vec<PromptArgument>>,
}
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PromptArgument {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
}
// Shared Types
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<serde_json::Value>,
pub sampling: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logging: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompts: Option<PromptsCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logging: Option<HashMap<String, serde_json::Value>>,
pub prompts: Option<HashMap<String, serde_json::Value>>,
pub resources: Option<ResourcesCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<ToolsCapabilities>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
pub tools: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub subscribe: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolsCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub input_schema: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Implementation {
pub struct EntityInfo {
pub name: String,
pub version: String,
}
@@ -291,10 +231,6 @@ pub struct Implementation {
#[serde(rename_all = "camelCase")]
pub struct Resource {
pub uri: Url,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
@@ -302,23 +238,17 @@ pub struct Resource {
#[serde(rename_all = "camelCase")]
pub struct ResourceContent {
pub uri: Url,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub blob: Option<String>,
pub data: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceTemplate {
pub uri_template: String,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mime_type: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -330,16 +260,13 @@ pub enum LoggingLevel {
Error,
}
// Client Notifications
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum NotificationType {
Initialized,
Progress,
Message,
ResourcesUpdated,
ResourcesListChanged,
ToolsListChanged,
PromptsListChanged,
}
impl NotificationType {
@@ -347,11 +274,6 @@ impl NotificationType {
match self {
NotificationType::Initialized => "notifications/initialized",
NotificationType::Progress => "notifications/progress",
NotificationType::Message => "notifications/message",
NotificationType::ResourcesUpdated => "notifications/resources/updated",
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
}
}
}
@@ -366,13 +288,12 @@ pub enum ClientNotification {
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ProgressParams {
pub progress_token: ProgressToken,
pub progress_token: String,
pub progress: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<f64>,
}
pub type ProgressToken = String;
// Helper Types that don't map directly to the protocol
pub enum CompletionTotal {
Exact(u32),
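
Since `PromptInfo` above derives `Deserialize` with camelCase renaming and optional fields, a `prompts/list` entry parses as in this small sketch (the JSON payload is invented for illustration):

```
use serde::Deserialize;

#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct PromptInfo {
    name: String,
    description: Option<String>,
    arguments: Option<Vec<PromptArgument>>,
}

#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct PromptArgument {
    name: String,
    description: Option<String>,
    required: Option<bool>,
}

fn main() {
    // Missing optional fields (like `description`) deserialize to `None`.
    let json = r#"{"name":"summarize","arguments":[{"name":"path","required":true}]}"#;
    let prompt: PromptInfo = serde_json::from_str(json).unwrap();
    assert_eq!(prompt.name, "summarize");
    assert_eq!(prompt.arguments.unwrap()[0].name, "path");
}
```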

View File

@@ -864,11 +864,7 @@ impl Copilot {
let buffer = buffer.read(cx);
let uri = registered_buffer.uri.clone();
let position = position.to_point_utf16(buffer);
let settings = language_settings(
buffer.language_at(position).map(|l| l.name()),
buffer.file(),
cx,
);
let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx);
let tab_size = settings.tab_size;
let hard_tabs = settings.hard_tabs;
let relative_path = buffer

Some files were not shown because too many files have changed in this diff.