Compare commits
17 Commits
thorsten-g
...
v0.157.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6192aa1469 | ||
|
|
2b902c185e | ||
|
|
e2e95f2c49 | ||
|
|
5e3a02b3f3 | ||
|
|
bc768d8586 | ||
|
|
84caa0cf4c | ||
|
|
8445b4adfb | ||
|
|
86e2510414 | ||
|
|
5222a1162c | ||
|
|
be25c51c5b | ||
|
|
ef0eeb4853 | ||
|
|
4e0db8ba32 | ||
|
|
ed379fe233 | ||
|
|
eb933ce203 | ||
|
|
515f9a6c7d | ||
|
|
9c33d723f8 | ||
|
|
5b303e892a |
6
.github/actions/check_style/action.yml
vendored
6
.github/actions/check_style/action.yml
vendored
@@ -7,3 +7,9 @@ runs:
|
||||
- name: cargo fmt
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: cargo fmt --all -- --check
|
||||
|
||||
- name: Find modified migrations
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: |
|
||||
export SQUAWK_GITHUB_TOKEN=${{ github.token }}
|
||||
. ./script/squawk
|
||||
|
||||
12
.github/pull_request_template.md
vendored
12
.github/pull_request_template.md
vendored
@@ -2,4 +2,14 @@ Closes #ISSUE
|
||||
|
||||
Release Notes:
|
||||
|
||||
- N/A *or* Added/Fixed/Improved ...
|
||||
- Added/Fixed/Improved ...
|
||||
|
||||
Optionally, include screenshots / media showcasing your addition that can be included in the release notes.
|
||||
|
||||
### Or...
|
||||
|
||||
Closes #ISSUE
|
||||
|
||||
Release Notes:
|
||||
|
||||
- N/A
|
||||
|
||||
2
.github/workflows/bump_collab_staging.yml
vendored
2
.github/workflows/bump_collab_staging.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
2
.github/workflows/bump_patch_version.yml
vendored
2
.github/workflows/bump_patch_version.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.branch }}
|
||||
ssh-key: ${{ secrets.ZED_BOT_DEPLOY_KEY }}
|
||||
|
||||
113
.github/workflows/ci.yml
vendored
113
.github/workflows/ci.yml
vendored
@@ -26,28 +26,36 @@ env:
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
jobs:
|
||||
migration_checks:
|
||||
name: Check Postgres and Protobuf migrations, mergability
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
style:
|
||||
timeout-minutes: 60
|
||||
name: Check formatting and spelling
|
||||
runs-on:
|
||||
- self-hosted
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0 # fetch full history
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Remove untracked files
|
||||
run: git clean -df
|
||||
|
||||
- name: Find modified migrations
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: |
|
||||
export SQUAWK_GITHUB_TOKEN=${{ github.token }}
|
||||
. ./script/squawk
|
||||
- name: Check spelling
|
||||
run: script/check-spelling
|
||||
|
||||
- name: Run style checks
|
||||
uses: ./.github/actions/check_style
|
||||
|
||||
- name: Check unused dependencies
|
||||
uses: bnjbvr/cargo-machete@main
|
||||
|
||||
- name: Check licenses are present
|
||||
run: script/check-licenses
|
||||
|
||||
- name: Check license generation
|
||||
run: script/generate-licenses /tmp/zed_licenses_output
|
||||
|
||||
- name: Ensure fresh merge
|
||||
shell: bash -euxo pipefail {0}
|
||||
@@ -69,24 +77,6 @@ jobs:
|
||||
input: "crates/proto/proto/"
|
||||
against: "https://github.com/${GITHUB_REPOSITORY}.git#branch=${BUF_BASE_BRANCH},subdir=crates/proto/proto/"
|
||||
|
||||
style:
|
||||
timeout-minutes: 60
|
||||
name: Check formatting and spelling
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on:
|
||||
- buildjet-8vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- name: Run style checks
|
||||
uses: ./.github/actions/check_style
|
||||
|
||||
- name: Check for typos
|
||||
uses: crate-ci/typos@v1.24.6
|
||||
with:
|
||||
config: ./typos.toml
|
||||
|
||||
macos_tests:
|
||||
timeout-minutes: 60
|
||||
name: (macOS) Run Clippy and tests
|
||||
@@ -95,32 +85,21 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: cargo clippy
|
||||
run: ./script/clippy
|
||||
|
||||
- name: Check unused dependencies
|
||||
uses: bnjbvr/cargo-machete@main
|
||||
|
||||
- name: Check licenses
|
||||
run: |
|
||||
script/check-licenses
|
||||
script/generate-licenses /tmp/zed_licenses_output
|
||||
|
||||
- name: Run tests
|
||||
uses: ./.github/actions/run_tests
|
||||
|
||||
- name: Build collab
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p collab
|
||||
run: cargo build -p collab
|
||||
|
||||
- name: Build other binaries and features
|
||||
run: |
|
||||
RUSTFLAGS="-D warnings" cargo build --workspace --bins --all-features
|
||||
cargo check -p gpui --features "macos-blade"
|
||||
RUSTFLAGS="-D warnings" cargo build -p remote_server
|
||||
run: cargo build --workspace --bins --all-features; cargo check -p gpui --features "macos-blade"
|
||||
|
||||
linux_tests:
|
||||
timeout-minutes: 60
|
||||
@@ -132,7 +111,7 @@ jobs:
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -152,33 +131,7 @@ jobs:
|
||||
uses: ./.github/actions/run_tests
|
||||
|
||||
- name: Build Zed
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p zed
|
||||
|
||||
build_remote_server:
|
||||
timeout-minutes: 60
|
||||
name: (Linux) Build Remote Server
|
||||
runs-on:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
cache-provider: "buildjet"
|
||||
|
||||
- name: Install Clang & Mold
|
||||
run: ./script/remote-server && ./script/install-mold 2.34.0
|
||||
|
||||
- name: Build Remote Server
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p remote_server
|
||||
run: cargo build -p zed
|
||||
|
||||
# todo(windows): Actually run the tests
|
||||
windows_tests:
|
||||
@@ -187,7 +140,7 @@ jobs:
|
||||
runs-on: hosted-windows-1
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -202,7 +155,7 @@ jobs:
|
||||
run: cargo xtask clippy
|
||||
|
||||
- name: Build Zed
|
||||
run: $env:RUSTFLAGS="-D warnings"; cargo build
|
||||
run: cargo build -p zed
|
||||
|
||||
bundle-mac:
|
||||
timeout-minutes: 60
|
||||
@@ -228,7 +181,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
# We need to fetch more than one commit so that `script/draft-release-notes`
|
||||
# is able to diff between the current and previous tag.
|
||||
@@ -266,20 +219,20 @@ jobs:
|
||||
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
|
||||
|
||||
- name: Upload app bundle (universal) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}.dmg
|
||||
path: target/release/Zed.dmg
|
||||
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
|
||||
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
|
||||
|
||||
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
|
||||
@@ -313,7 +266,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -330,7 +283,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
|
||||
@@ -360,7 +313,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -377,7 +330,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
|
||||
|
||||
2
.github/workflows/danger.yml
vendored
2
.github/workflows/danger.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- uses: pnpm/action-setup@fe02b34f77f8bc703788d5817da081398fad5dd2 # v4.0.0
|
||||
with:
|
||||
|
||||
11
.github/workflows/deploy_cloudflare.yml
vendored
11
.github/workflows/deploy_cloudflare.yml
vendored
@@ -8,12 +8,11 @@ on:
|
||||
jobs:
|
||||
deploy-docs:
|
||||
name: Deploy Docs
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -37,28 +36,28 @@ jobs:
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Deploy Docs
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: pages deploy target/deploy --project-name=docs
|
||||
|
||||
- name: Deploy Install
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: r2 object put -f script/install.sh zed-open-source-website-assets/install.sh
|
||||
|
||||
- name: Deploy Docs Workers
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: deploy .cloudflare/docs-proxy/src/worker.js
|
||||
|
||||
- name: Deploy Install Workers
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
|
||||
8
.github/workflows/deploy_collab.yml
vendored
8
.github/workflows/deploy_collab.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
needs: style
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -71,7 +71,7 @@ jobs:
|
||||
run: doctl registry login
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
10
.github/workflows/docs.yml
vendored
10
.github/workflows/docs.yml
vendored
@@ -11,11 +11,10 @@ on:
|
||||
jobs:
|
||||
check_formatting:
|
||||
name: "Check formatting"
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- uses: pnpm/action-setup@fe02b34f77f8bc703788d5817da081398fad5dd2 # v4.0.0
|
||||
with:
|
||||
@@ -30,8 +29,5 @@ jobs:
|
||||
false
|
||||
}
|
||||
|
||||
- name: Check for Typos with Typos-CLI
|
||||
uses: crate-ci/typos@v1.24.6
|
||||
with:
|
||||
config: ./typos.toml
|
||||
files: ./docs/
|
||||
- name: Check spelling
|
||||
run: script/check-spelling docs/
|
||||
|
||||
2
.github/workflows/publish_extension_cli.yml
vendored
2
.github/workflows/publish_extension_cli.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
- ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/randomized_tests.yml
vendored
2
.github/workflows/randomized_tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/release_actions.yml
vendored
2
.github/workflows/release_actions.yml
vendored
@@ -1,5 +1,3 @@
|
||||
name: Release Actions
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
12
.github/workflows/release_nightly.yml
vendored
12
.github/workflows/release_nightly.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
needs: style
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -75,7 +75,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -182,7 +182,7 @@ jobs:
|
||||
- bundle-linux-arm
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
name: Update All Top Ranking Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
@@ -10,16 +8,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@f3bcaebff5eace81a1c062af9f9011aae482ca9d # v3
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
|
||||
- name: Install Python 3.13
|
||||
run: uv python install 3.13
|
||||
- name: Install dependencies
|
||||
run: uv sync --project script/update_top_ranking_issues -p 3.13
|
||||
- name: Run script
|
||||
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
|
||||
python-version: "3.11"
|
||||
architecture: "x64"
|
||||
cache: "pip"
|
||||
- run: pip install -r script/update_top_ranking_issues/requirements.txt
|
||||
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
name: Update Weekly Top Ranking Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 15 * * *"
|
||||
@@ -10,16 +8,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@f3bcaebff5eace81a1c062af9f9011aae482ca9d # v3
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
|
||||
- name: Install Python 3.13
|
||||
run: uv python install 3.13
|
||||
- name: Install dependencies
|
||||
run: uv sync --project script/update_top_ranking_issues -p 3.13
|
||||
- name: Run script
|
||||
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
|
||||
python-version: "3.11"
|
||||
architecture: "x64"
|
||||
cache: "pip"
|
||||
- run: pip install -r script/update_top_ranking_issues/requirements.txt
|
||||
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
|
||||
|
||||
797
Cargo.lock
generated
797
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
24
Cargo.toml
24
Cargo.toml
@@ -52,6 +52,7 @@ members = [
|
||||
"crates/indexed_docs",
|
||||
"crates/inline_completion_button",
|
||||
"crates/install_cli",
|
||||
"crates/isahc_http_client",
|
||||
"crates/journal",
|
||||
"crates/language",
|
||||
"crates/language_model",
|
||||
@@ -87,7 +88,6 @@ members = [
|
||||
"crates/remote",
|
||||
"crates/remote_server",
|
||||
"crates/repl",
|
||||
"crates/reqwest_client",
|
||||
"crates/rich_text",
|
||||
"crates/rope",
|
||||
"crates/rpc",
|
||||
@@ -122,7 +122,6 @@ members = [
|
||||
"crates/ui",
|
||||
"crates/ui_input",
|
||||
"crates/ui_macros",
|
||||
"crates/reqwest_client",
|
||||
"crates/util",
|
||||
"crates/vcs_menu",
|
||||
"crates/vim",
|
||||
@@ -145,6 +144,7 @@ members = [
|
||||
"extensions/elm",
|
||||
"extensions/emmet",
|
||||
"extensions/erlang",
|
||||
"extensions/gleam",
|
||||
"extensions/glsl",
|
||||
"extensions/haskell",
|
||||
"extensions/html",
|
||||
@@ -156,6 +156,7 @@ members = [
|
||||
"extensions/proto",
|
||||
"extensions/purescript",
|
||||
"extensions/ruff",
|
||||
"extensions/ruby",
|
||||
"extensions/slash-commands-example",
|
||||
"extensions/snippets",
|
||||
"extensions/svelte",
|
||||
@@ -219,7 +220,7 @@ git = { path = "crates/git" }
|
||||
git_hosting_providers = { path = "crates/git_hosting_providers" }
|
||||
go_to_line = { path = "crates/go_to_line" }
|
||||
google_ai = { path = "crates/google_ai" }
|
||||
gpui = { path = "crates/gpui", default-features = false, features = ["http_client"]}
|
||||
gpui = { path = "crates/gpui" }
|
||||
gpui_macros = { path = "crates/gpui_macros" }
|
||||
headless = { path = "crates/headless" }
|
||||
html_to_markdown = { path = "crates/html_to_markdown" }
|
||||
@@ -228,6 +229,7 @@ image_viewer = { path = "crates/image_viewer" }
|
||||
indexed_docs = { path = "crates/indexed_docs" }
|
||||
inline_completion_button = { path = "crates/inline_completion_button" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
isahc_http_client = { path = "crates/isahc_http_client" }
|
||||
journal = { path = "crates/journal" }
|
||||
language = { path = "crates/language" }
|
||||
language_model = { path = "crates/language_model" }
|
||||
@@ -264,7 +266,6 @@ release_channel = { path = "crates/release_channel" }
|
||||
remote = { path = "crates/remote" }
|
||||
remote_server = { path = "crates/remote_server" }
|
||||
repl = { path = "crates/repl" }
|
||||
reqwest_client = { path = "crates/reqwest_client" }
|
||||
rich_text = { path = "crates/rich_text" }
|
||||
rope = { path = "crates/rope" }
|
||||
rpc = { path = "crates/rpc" }
|
||||
@@ -326,7 +327,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
|
||||
async-recursion = "1.0.0"
|
||||
async-tar = "0.5.0"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.24"
|
||||
async-tungstenite = "0.23"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
base64 = "0.22"
|
||||
@@ -335,7 +336,6 @@ blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
|
||||
blake3 = "1.5.3"
|
||||
bytes = "1.0"
|
||||
cargo_metadata = "0.18"
|
||||
cargo_toml = "0.20"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
@@ -367,6 +367,10 @@ ignore = "0.4.22"
|
||||
image = "0.25.1"
|
||||
indexmap = { version = "1.6.2", features = ["serde"] }
|
||||
indoc = "2"
|
||||
# We explicitly disable http2 support in isahc.
|
||||
isahc = { version = "1.7.2", default-features = false, features = [
|
||||
"text-decoding",
|
||||
] }
|
||||
itertools = "0.13.0"
|
||||
jsonwebtoken = "9.3"
|
||||
libc = "0.2"
|
||||
@@ -391,7 +395,6 @@ pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
rand = "0.8.5"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "fd110f6998da16bbca97b6dddda9be7827c50e29", default-features = false, features = ["charset", "http2", "macos-system-configuration", "rustls-tls-native-roots", "stream"]}
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { version = "0.15", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
@@ -436,7 +439,7 @@ time = { version = "0.3", features = [
|
||||
] }
|
||||
tiny_http = "0.8"
|
||||
toml = "0.8"
|
||||
tokio = { version = "1" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tower-http = "0.4.4"
|
||||
tree-sitter = { version = "0.23", features = ["wasm"] }
|
||||
tree-sitter-bash = "0.23"
|
||||
@@ -449,7 +452,6 @@ tree-sitter-go = "0.23"
|
||||
tree-sitter-go-mod = { git = "https://github.com/zed-industries/tree-sitter-go-mod", rev = "a9aea5e358cde4d0f8ff20b7bc4fa311e359c7ca", package = "tree-sitter-gomod" }
|
||||
tree-sitter-gowork = { git = "https://github.com/zed-industries/tree-sitter-go-work", rev = "acb0617bf7f4fda02c6217676cc64acb89536dc7" }
|
||||
tree-sitter-heex = { git = "https://github.com/zed-industries/tree-sitter-heex", rev = "1dd45142fbb05562e35b2040c6129c9bca346592" }
|
||||
tree-sitter-diff = "0.1.0"
|
||||
tree-sitter-html = "0.20"
|
||||
tree-sitter-jsdoc = "0.23"
|
||||
tree-sitter-json = "0.23"
|
||||
@@ -477,11 +479,9 @@ wasmtime = { version = "24", default-features = false, features = [
|
||||
wasmtime-wasi = "24"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.201"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
git = "https://github.com/zed-industries/async-stripe"
|
||||
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
version = "0.39"
|
||||
default-features = false
|
||||
features = [
|
||||
"runtime-tokio-hyper-rustls",
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
[build]
|
||||
dockerfile = "Dockerfile-cross"
|
||||
@@ -13,9 +13,30 @@ ARG GITHUB_SHA
|
||||
|
||||
ENV GITHUB_SHA=$GITHUB_SHA
|
||||
|
||||
# At some point in the past 3 weeks, additional dependencies on `xkbcommon` and
|
||||
# `xkbcommon-x11` were introduced into collab.
|
||||
#
|
||||
# A `git bisect` points to this commit as being the culprit: `b8e6098f60e5dabe98fe8281f993858dacc04a55`.
|
||||
#
|
||||
# Now when we try to build collab for the Docker image, it fails with the following
|
||||
# error:
|
||||
#
|
||||
# ```
|
||||
# 985.3 = note: /usr/bin/ld: cannot find -lxkbcommon: No such file or directory
|
||||
# 985.3 /usr/bin/ld: cannot find -lxkbcommon-x11: No such file or directory
|
||||
# 985.3 collect2: error: ld returned 1 exit status
|
||||
# ```
|
||||
#
|
||||
# The last successful deploys were at:
|
||||
# - Staging: `4f408ec65a3867278322a189b4eb20f1ab51f508`
|
||||
# - Production: `fc4c533d0a8c489e5636a4249d2b52a80039fbd7`
|
||||
#
|
||||
# Also add `cmake`, since we need it to build `wasmtime`.
|
||||
#
|
||||
# Installing these as a temporary workaround, but I think ideally we'd want to figure
|
||||
# out what caused them to be included in the first place.
|
||||
RUN apt-get update; \
|
||||
apt-get install -y --no-install-recommends cmake
|
||||
apt-get install -y --no-install-recommends libxkbcommon-dev libxkbcommon-x11-dev cmake
|
||||
|
||||
RUN --mount=type=cache,target=./script/node_modules \
|
||||
--mount=type=cache,target=/usr/local/cargo/registry \
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG CROSS_BASE_IMAGE
|
||||
FROM ${CROSS_BASE_IMAGE}
|
||||
WORKDIR /app
|
||||
ARG TZ=Etc/UTC \
|
||||
LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV CARGO_TERM_COLOR=always
|
||||
|
||||
COPY script/install-mold script/
|
||||
RUN ./script/install-mold "2.34.0"
|
||||
COPY script/remote-server script/
|
||||
RUN ./script/remote-server
|
||||
|
||||
COPY . .
|
||||
@@ -1,16 +0,0 @@
|
||||
.git
|
||||
.github
|
||||
**/.gitignore
|
||||
**/.gitkeep
|
||||
.gitattributes
|
||||
.mailmap
|
||||
**/target
|
||||
zed.xcworkspace
|
||||
.DS_Store
|
||||
compose.yml
|
||||
plugins/bin
|
||||
script/node_modules
|
||||
styles/node_modules
|
||||
crates/collab/static/styles.css
|
||||
vendor/bin
|
||||
assets/themes/
|
||||
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-diff"><path d="M12 3v14"/><path d="M5 10h14"/><path d="M5 21h14"/></svg>
|
||||
|
Before Width: | Height: | Size: 275 B |
@@ -20,7 +20,6 @@
|
||||
"bashrc": "terminal",
|
||||
"bmp": "image",
|
||||
"c": "c",
|
||||
"c++": "cpp",
|
||||
"cc": "cpp",
|
||||
"cjs": "javascript",
|
||||
"coffee": "coffeescript",
|
||||
@@ -28,7 +27,6 @@
|
||||
"cpp": "cpp",
|
||||
"css": "css",
|
||||
"csv": "storage",
|
||||
"cxx": "cpp",
|
||||
"cts": "typescript",
|
||||
"dart": "dart",
|
||||
"dat": "storage",
|
||||
@@ -68,13 +66,11 @@
|
||||
"heex": "elixir",
|
||||
"heic": "image",
|
||||
"heif": "image",
|
||||
"hh": "cpp",
|
||||
"hpp": "cpp",
|
||||
"hrl": "erlang",
|
||||
"hs": "haskell",
|
||||
"htm": "template",
|
||||
"html": "template",
|
||||
"hxx": "cpp",
|
||||
"ib": "storage",
|
||||
"ico": "image",
|
||||
"ini": "settings",
|
||||
|
||||
@@ -664,8 +664,7 @@
|
||||
"shift-up": "terminal::ScrollLineUp",
|
||||
"shift-down": "terminal::ScrollLineDown",
|
||||
"shift-home": "terminal::ScrollToTop",
|
||||
"shift-end": "terminal::ScrollToBottom",
|
||||
"ctrl-shift-space": "terminal::ToggleViMode"
|
||||
"shift-end": "terminal::ScrollToBottom"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -395,7 +395,6 @@
|
||||
// Change the default action on `menu::Confirm` by setting the parameter
|
||||
// "alt-cmd-o": ["projects::OpenRecent", {"create_new_window": true }],
|
||||
"alt-cmd-o": "projects::OpenRecent",
|
||||
"ctrl-cmd-o": "projects::OpenRemote",
|
||||
"alt-cmd-b": "branches::OpenRecent",
|
||||
"ctrl-~": "workspace::NewTerminal",
|
||||
"cmd-s": "workspace::Save",
|
||||
@@ -679,8 +678,7 @@
|
||||
"cmd-home": "terminal::ScrollToTop",
|
||||
"cmd-end": "terminal::ScrollToBottom",
|
||||
"shift-home": "terminal::ScrollToTop",
|
||||
"shift-end": "terminal::ScrollToBottom",
|
||||
"ctrl-shift-space": "terminal::ToggleViMode"
|
||||
"shift-end": "terminal::ScrollToBottom"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -128,10 +128,6 @@
|
||||
"shift-m": "vim::WindowMiddle",
|
||||
"shift-l": "vim::WindowBottom",
|
||||
// z commands
|
||||
"z enter": ["workspace::SendKeystrokes", "z t ^"],
|
||||
"z -": ["workspace::SendKeystrokes", "z b ^"],
|
||||
"z ^": ["workspace::SendKeystrokes", "shift-h k z b ^"],
|
||||
"z +": ["workspace::SendKeystrokes", "shift-l j z t ^"],
|
||||
"z t": "editor::ScrollCursorTop",
|
||||
"z z": "editor::ScrollCursorCenter",
|
||||
"z .": ["workspace::SendKeystrokes", "z z ^"],
|
||||
@@ -256,7 +252,6 @@
|
||||
"@": ["vim::PushOperator", "ReplayRegister"],
|
||||
"ctrl-pagedown": "pane::ActivateNextItem",
|
||||
"ctrl-pageup": "pane::ActivatePrevItem",
|
||||
"insert": "vim::InsertBefore",
|
||||
// tree-sitter related commands
|
||||
"[ x": "editor::SelectLargerSyntaxNode",
|
||||
"] x": "editor::SelectSmallerSyntaxNode",
|
||||
@@ -339,8 +334,7 @@
|
||||
"ctrl-t": "vim::Indent",
|
||||
"ctrl-d": "vim::Outdent",
|
||||
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
|
||||
"ctrl-r": ["vim::PushOperator", "Register"],
|
||||
"insert": "vim::ToggleReplace"
|
||||
"ctrl-r": ["vim::PushOperator", "Register"]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -359,8 +353,7 @@
|
||||
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
|
||||
"backspace": "vim::UndoReplace",
|
||||
"tab": "vim::Tab",
|
||||
"enter": "vim::Enter",
|
||||
"insert": "vim::InsertBefore"
|
||||
"enter": "vim::Enter"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,33 +1,85 @@
|
||||
<task_description>
|
||||
|
||||
The user of a code editor wants to make a change to their codebase.
|
||||
You must describe the change using the following XML structure:
|
||||
# Code Change Workflow
|
||||
|
||||
- <patch> - A group of related code changes.
|
||||
Child tags:
|
||||
- <title> (required) - A high-level description of the changes. This should be as short
|
||||
as possible, possibly using common abbreviations.
|
||||
- <edit> (1 or more) - An edit to make at a particular range within a file.
|
||||
Includes the following child tags:
|
||||
- <path> (required) - The path to the file that will be changed.
|
||||
- <description> (optional) - An arbitrarily-long comment that describes the purpose
|
||||
of this edit.
|
||||
- <old_text> (optional) - An excerpt from the file's current contents that uniquely
|
||||
identifies a range within the file where the edit should occur. If this tag is not
|
||||
specified, then the entire file will be used as the range.
|
||||
- <new_text> (required) - The new text to insert into the file.
|
||||
- <operation> (required) - The type of change that should occur at the given range
|
||||
of the file. Must be one of the following values:
|
||||
- `update`: Replaces the entire range with the new text.
|
||||
- `insert_before`: Inserts the new text before the range.
|
||||
- `insert_after`: Inserts new text after the range.
|
||||
- `create`: Creates a new file with the given path and the new text.
|
||||
- `delete`: Deletes the specified range from the file.
|
||||
Your task is to guide the user through code changes using a series of steps. Each step should describe a high-level change, which can consist of multiple edits to distinct locations in the codebase.
|
||||
|
||||
## Output Example
|
||||
|
||||
Provide output as XML, with the following format:
|
||||
|
||||
<step>
|
||||
Update the Person struct to store an age
|
||||
|
||||
```rust
|
||||
struct Person {
|
||||
// existing fields...
|
||||
age: u8,
|
||||
height: f32,
|
||||
// existing fields...
|
||||
}
|
||||
|
||||
impl Person {
|
||||
fn age(&self) -> u8 {
|
||||
self.age
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<path>src/person.rs</path>
|
||||
<operation>insert_before</operation>
|
||||
<search>height: f32,</search>
|
||||
<description>Add the age field</description>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/person.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<search>impl Person {</search>
|
||||
<description>Add the age getter</description>
|
||||
</edit>
|
||||
</step>
|
||||
|
||||
## Output Format
|
||||
|
||||
First, each `<step>` must contain a written description of the change that should be made. The description should begin with a high-level overview, and can contain markdown code blocks as well. The description should be self-contained and actionable.
|
||||
|
||||
After the description, each `<step>` must contain one or more `<edit>` tags, each of which refer to a specific range in a source file. Each `<edit>` tag must contain the following child tags:
|
||||
|
||||
### `<path>` (required)
|
||||
|
||||
This tag contains the path to the file that will be changed. It can be an existing path, or a path that should be created.
|
||||
|
||||
### `<search>` (optional)
|
||||
|
||||
This tag contains a search string to locate in the source file, e.g. `pub fn baz() {`. If not provided, the new content will be inserted at the top of the file. Make sure to produce a string that exists in the source file and that isn't ambiguous. When there's ambiguity, add more lines to the search to eliminate it.
|
||||
|
||||
### `<description>` (required)
|
||||
|
||||
This tag contains a single-line description of the edit that should be made at the given location.
|
||||
|
||||
### `<operation>` (required)
|
||||
|
||||
This tag indicates what type of change should be made, relative to the given location. It can be one of the following:
|
||||
- `update`: Rewrites the specified string entirely based on the given description.
|
||||
- `create`: Creates a new file with the given path based on the provided description.
|
||||
- `insert_before`: Inserts new text based on the given description before the specified search string.
|
||||
- `insert_after`: Inserts new text based on the given description after the specified search string.
|
||||
- `delete`: Deletes the specified string from the containing file.
|
||||
|
||||
<guidelines>
|
||||
- Never provide multiple edits whose ranges intersect each other. Instead, merge them into one edit.
|
||||
- Prefer multiple edits to smaller, disjoint ranges, rather than one edit to a larger range.
|
||||
- There's no need to escape angle brackets within XML tags.
|
||||
- There's no need to describe *what* to do, just *where* to do it.
|
||||
- Only reference locations that actually exist (unless you're creating a file).
|
||||
- If creating a file, assume any subsequent updates are included at the time of creation.
|
||||
- Don't create and then update a file. Always create new files in one hot.
|
||||
- Prefer multiple edits to smaller regions, as opposed to one big edit to a larger region.
|
||||
- Don't produce edits that intersect each other. In that case, merge them into a bigger edit.
|
||||
- Never nest an edit with another edit. Never include CDATA. All edits are leaf nodes.
|
||||
- Descriptions are required for all edits except delete.
|
||||
- When generating multiple edits, ensure the descriptions are specific to each individual operation.
|
||||
- Avoid referring to the search string in the description. Focus on the change to be made, not the location where it's made. That's implicit with the `search` string you provide.
|
||||
- Don't generate multiple edits at the same location. Instead, combine them together in a single edit with a succinct combined description.
|
||||
- Always ensure imports are added if you're referencing symbols that are not in scope.
|
||||
</guidelines>
|
||||
|
||||
@@ -72,137 +124,189 @@ Update all shapes to store their origin as an (x, y) tuple and implement Display
|
||||
<message role="assistant">
|
||||
We'll need to update both the rectangle and circle modules.
|
||||
|
||||
<patch>
|
||||
<title>Add origins and display impls to shapes</title>
|
||||
<edit>
|
||||
<path>src/shapes/rectangle.rs</path>
|
||||
<description>Add the origin field to Rectangle struct</description>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>
|
||||
pub struct Rectangle {
|
||||
</old_text>
|
||||
<new_text>
|
||||
origin: (f64, f64),
|
||||
</new_text>
|
||||
</edit>
|
||||
<step>
|
||||
Add origin fields to both shape types.
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/rectangle.rs</path>
|
||||
<description>Update the Rectangle's new function to take an origin parameter</description>
|
||||
<operation>update</operation>
|
||||
<old_text>
|
||||
fn new(width: f64, height: f64) -> Self {
|
||||
Rectangle { width, height }
|
||||
}
|
||||
</old_text>
|
||||
<new_text>
|
||||
fn new(origin: (f64, f64), width: f64, height: f64) -> Self {
|
||||
Rectangle { origin, width, height }
|
||||
}
|
||||
</new_text>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/circle.rs</path>
|
||||
<description>Add the origin field to Circle struct</description>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>
|
||||
pub struct Circle {
|
||||
radius: f64,
|
||||
</old_text>
|
||||
<new_text>
|
||||
```rust
|
||||
struct Rectangle {
|
||||
// existing fields ...
|
||||
origin: (f64, f64),
|
||||
</new_text>
|
||||
}
|
||||
```
|
||||
|
||||
```rust
|
||||
struct Circle {
|
||||
// existing fields ...
|
||||
origin: (f64, f64),
|
||||
}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/rectangle.rs</path>
|
||||
<operation>insert_before</operation>
|
||||
<search>
|
||||
width: f64,
|
||||
height: f64,
|
||||
</search>
|
||||
<description>Add the origin field to Rectangle</description>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/circle.rs</path>
|
||||
<description>Update the Circle's new function to take an origin parameter</description>
|
||||
<operation>insert_before</operation>
|
||||
<search>
|
||||
radius: f64,
|
||||
</search>
|
||||
<description>Add the origin field to Circle</description>
|
||||
</edit>
|
||||
|
||||
<step>
|
||||
Update both shape's constructors to take an origin.
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/rectangle.rs</path>
|
||||
<operation>update</operation>
|
||||
<old_text>
|
||||
fn new(radius: f64) -> Self {
|
||||
Circle { radius }
|
||||
}
|
||||
</old_text>
|
||||
<new_text>
|
||||
fn new(origin: (f64, f64), radius: f64) -> Self {
|
||||
Circle { origin, radius }
|
||||
}
|
||||
</new_text>
|
||||
<search>
|
||||
fn new(width: f64, height: f64) -> Self {
|
||||
Rectangle { width, height }
|
||||
}
|
||||
</search>
|
||||
<description>Update the Rectangle new function to take an origin</description>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/circle.rs</path>
|
||||
<operation>update</operation>
|
||||
<search>
|
||||
fn new(radius: f64) -> Self {
|
||||
Circle { radius }
|
||||
}
|
||||
</search>
|
||||
<description>Update the Circle new function to take an origin</description>
|
||||
</edit>
|
||||
</step>
|
||||
|
||||
<step>
|
||||
Implement Display for both shapes
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/rectangle.rs</path>
|
||||
<description>Add an import for the std::fmt module</description>
|
||||
<operation>insert_before</operation>
|
||||
<old_text>
|
||||
<search>
|
||||
struct Rectangle {
|
||||
</old_text>
|
||||
<new_text>
|
||||
use std::fmt;
|
||||
|
||||
</new_text>
|
||||
</search>
|
||||
<description>Add an import for the `std::fmt` module</description>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/rectangle.rs</path>
|
||||
<description>Add a Display implementation for Rectangle</description>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>
|
||||
<search>
|
||||
Rectangle { width, height }
|
||||
}
|
||||
}
|
||||
</old_text>
|
||||
<new_text>
|
||||
impl fmt::Display for Rectangle {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.format_struct(f, "Rectangle")
|
||||
.field("origin", &self.origin)
|
||||
.field("width", &self.width)
|
||||
.field("height", &self.height)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
</new_text>
|
||||
</search>
|
||||
<description>Add a Display implementation for Rectangle</description>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/circle.rs</path>
|
||||
<description>Add an import for the `std::fmt` module</description>
|
||||
<operation>insert_before</operation>
|
||||
<old_text>
|
||||
<search>
|
||||
struct Circle {
|
||||
</old_text>
|
||||
<new_text>
|
||||
use std::fmt;
|
||||
</new_text>
|
||||
</search>
|
||||
<description>Add an import for the `std::fmt` module</description>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/shapes/circle.rs</path>
|
||||
<description>Add a Display implementation for Circle</description>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>
|
||||
<search>
|
||||
Circle { radius }
|
||||
}
|
||||
}
|
||||
</old_text>
|
||||
<new_text>
|
||||
impl fmt::Display for Rectangle {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.format_struct(f, "Rectangle")
|
||||
.field("origin", &self.origin)
|
||||
.field("width", &self.width)
|
||||
.field("height", &self.height)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
</new_text>
|
||||
</search>
|
||||
<description>Add a Display implementation for Circle</description>
|
||||
</edit>
|
||||
</patch>
|
||||
|
||||
</step>
|
||||
</message>
|
||||
</example>
|
||||
|
||||
<example number=2>
|
||||
<message role="user">
|
||||
|
||||
```rs src/user.rs
|
||||
struct User {
|
||||
pub name: String,
|
||||
age: u32,
|
||||
email: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn new(name: String, age: u32, email: String) -> Self {
|
||||
User { name, age, email }
|
||||
}
|
||||
|
||||
pub fn print_info(&self) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Let's print all the user information and delete the email field.
|
||||
</message>
|
||||
|
||||
<message role="assistant">
|
||||
<step>
|
||||
Update the 'print_info' method to print user information
|
||||
|
||||
```rust
|
||||
impl User {
|
||||
// ... other methods ...
|
||||
|
||||
pub fn print_info(&self) {
|
||||
println!("Name: {name}, Age: {age}", name = self.name, age = self.age);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<path>src/user.rs</path>
|
||||
<operation>update</operation>
|
||||
<search>
|
||||
pub fn print_info(&self) {
|
||||
todo!()
|
||||
}
|
||||
</search>
|
||||
<description>Print all the user information</description>
|
||||
</edit>
|
||||
</step>
|
||||
|
||||
<step>
|
||||
Remove the 'email' field from the User struct
|
||||
|
||||
<edit>
|
||||
<path>src/user.rs</path>
|
||||
<operation>delete</operation>
|
||||
<search>
|
||||
email: String,
|
||||
</search>
|
||||
</edit>
|
||||
|
||||
<edit>
|
||||
<path>src/user.rs</path>
|
||||
<operation>update</operation>
|
||||
<symbol>
|
||||
fn new(name: String, age: u32, email: String) -> Self {
|
||||
User { name, age, email }
|
||||
}
|
||||
</symbol>
|
||||
<description>Remove email parameter from new method</description>
|
||||
</edit>
|
||||
</step>
|
||||
</message>
|
||||
</example>
|
||||
|
||||
You should think step by step. When possible, produce smaller, coherent logical steps as opposed to one big step that combines lots of heterogeneous edits.
|
||||
|
||||
</task_description>
|
||||
|
||||
496
assets/prompts/step_resolution.hbs
Normal file
496
assets/prompts/step_resolution.hbs
Normal file
@@ -0,0 +1,496 @@
|
||||
<overview>
|
||||
Your task is to map a step from a workflow to locations in source code where code needs to be changed to fulfill that step.
|
||||
Given a workflow containing background context plus a series of <step> tags, you will resolve *one* of these step tags to resolve to one or more locations in the code.
|
||||
With each location, you will produce a brief, one-line description of the changes to be made.
|
||||
|
||||
<guidelines>
|
||||
- There's no need to describe *what* to do, just *where* to do it.
|
||||
- Only reference locations that actually exist (unless you're creating a file).
|
||||
- If creating a file, assume any subsequent updates are included at the time of creation.
|
||||
- Don't create and then update a file. Always create new files in shot.
|
||||
- Prefer updating symbols lower in the syntax tree if possible.
|
||||
- Never include suggestions on a parent symbol and one of its children in the same suggestions block.
|
||||
- Never nest an operation with another operation or include CDATA or other content. All suggestions are leaf nodes.
|
||||
- Descriptions are required for all suggestions except delete.
|
||||
- When generating multiple suggestions, ensure the descriptions are specific to each individual operation.
|
||||
- Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
|
||||
- Don't generate multiple suggestions at the same location. Instead, combine them together in a single operation with a succinct combined description.
|
||||
- To add imports respond with a suggestion where the `"symbol"` key is set to `"#imports"`
|
||||
</guidelines>
|
||||
</overview>
|
||||
|
||||
<examples>
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/rectangle.rs
|
||||
struct Rectangle {
|
||||
width: f64,
|
||||
height: f64,
|
||||
}
|
||||
|
||||
impl Rectangle {
|
||||
fn new(width: f64, height: f64) -> Self {
|
||||
Rectangle { width, height }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
We need to add methods to calculate the area and perimeter of the rectangle. Can you help with that?
|
||||
</message>
|
||||
<message role="assistant">
|
||||
Sure, I can help with that!
|
||||
|
||||
<step>Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct</step>
|
||||
<step>Implement the 'Display' trait for the Rectangle struct</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct
|
||||
</step_to_resolve>
|
||||
|
||||
<incorrect_output reason="NEVER append multiple children at the same location.">
|
||||
{
|
||||
"title": "Add Rectangle methods",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "AppendChild",
|
||||
"path": "src/shapes.rs",
|
||||
"symbol": "impl Rectangle",
|
||||
"description": "Add calculate_area method"
|
||||
},
|
||||
{
|
||||
"kind": "AppendChild",
|
||||
"path": "src/shapes.rs",
|
||||
"symbol": "impl Rectangle",
|
||||
"description": "Add calculate_perimeter method"
|
||||
}
|
||||
]
|
||||
}
|
||||
</incorrect_output>
|
||||
|
||||
<correct_output>
|
||||
{
|
||||
"title": "Add Rectangle methods",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "AppendChild",
|
||||
"path": "src/shapes.rs",
|
||||
"symbol": "impl Rectangle",
|
||||
"description": "Add calculate area and perimeter methods"
|
||||
}
|
||||
]
|
||||
}
|
||||
</correct_output>
|
||||
|
||||
<step_to_resolve>
|
||||
Implement the 'Display' trait for the Rectangle struct
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Implement Display for Rectangle",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "InsertSiblingAfter",
|
||||
"path": "src/shapes.rs",
|
||||
"symbol": "impl Rectangle",
|
||||
"description": "Implement Display trait for Rectangle"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/user.rs
|
||||
struct User {
|
||||
pub name: String,
|
||||
age: u32,
|
||||
email: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn new(name: String, age: u32, email: String) -> Self {
|
||||
User { name, age, email }
|
||||
}
|
||||
|
||||
pub fn print_info(&self) {
|
||||
println!("Name: {}, Age: {}, Email: {}", self.name, self.age, self.email);
|
||||
}
|
||||
}
|
||||
```
|
||||
</message>
|
||||
<message role="assistant">
|
||||
Certainly!
|
||||
<step>Update the 'print_info' method to use formatted output</step>
|
||||
<step>Remove the 'email' field from the User struct</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Update the 'print_info' method to use formatted output
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Use formatted output",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/user.rs",
|
||||
"symbol": "impl User pub fn print_info",
|
||||
"description": "Use formatted output"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
|
||||
<step_to_resolve>
|
||||
Remove the 'email' field from the User struct
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Remove email field",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "Delete",
|
||||
"path": "src/user.rs",
|
||||
"symbol": "struct User email"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/vehicle.rs
|
||||
struct Vehicle {
|
||||
make: String,
|
||||
model: String,
|
||||
year: u32,
|
||||
}
|
||||
|
||||
impl Vehicle {
|
||||
fn new(make: String, model: String, year: u32) -> Self {
|
||||
Vehicle { make, model, year }
|
||||
}
|
||||
|
||||
fn print_year(&self) {
|
||||
println!("Year: {}", self.year);
|
||||
}
|
||||
}
|
||||
```
|
||||
</message>
|
||||
<message role="assistant">
|
||||
<step>Add a 'use std::fmt;' statement at the beginning of the file</step>
|
||||
<step>Add a new method 'start_engine' in the Vehicle impl block</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Add a 'use std::fmt;' statement at the beginning of the file
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Add use std::fmt statement",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "PrependChild",
|
||||
"path": "src/vehicle.rs",
|
||||
"symbol": "#imports",
|
||||
"description": "Add 'use std::fmt' statement"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
|
||||
<step_to_resolve>
|
||||
Add a new method 'start_engine' in the Vehicle impl block
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Add start_engine method",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "InsertSiblingAfter",
|
||||
"path": "src/vehicle.rs",
|
||||
"symbol": "impl Vehicle fn new",
|
||||
"description": "Add start_engine method"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/employee.rs
|
||||
struct Employee {
|
||||
name: String,
|
||||
position: String,
|
||||
salary: u32,
|
||||
department: String,
|
||||
}
|
||||
|
||||
impl Employee {
|
||||
fn new(name: String, position: String, salary: u32, department: String) -> Self {
|
||||
Employee { name, position, salary, department }
|
||||
}
|
||||
|
||||
fn print_details(&self) {
|
||||
println!("Name: {}, Position: {}, Salary: {}, Department: {}",
|
||||
self.name, self.position, self.salary, self.department);
|
||||
}
|
||||
|
||||
fn give_raise(&mut self, amount: u32) {
|
||||
self.salary += amount;
|
||||
}
|
||||
}
|
||||
```
|
||||
</message>
|
||||
<message role="assistant">
|
||||
<step>Make salary an f32</step>
|
||||
<step>Remove the 'department' field and update the 'print_details' method</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Make salary an f32
|
||||
</step_to_resolve>
|
||||
|
||||
<incorrect_output reason="NEVER include suggestions on a parent symbol and one of its children in the same suggestions block.">
|
||||
{
|
||||
"title": "Change salary to f32",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/employee.rs",
|
||||
"symbol": "struct Employee",
|
||||
"description": "Change the type of salary to an f32"
|
||||
},
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/employee.rs",
|
||||
"symbol": "struct Employee salary",
|
||||
"description": "Change the type to an f32"
|
||||
}
|
||||
]
|
||||
}
|
||||
</incorrect_output>
|
||||
|
||||
<correct_output>
|
||||
{
|
||||
"title": "Change salary to f32",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/employee.rs",
|
||||
"symbol": "struct Employee salary",
|
||||
"description": "Change the type to an f32"
|
||||
}
|
||||
]
|
||||
}
|
||||
</correct_output>
|
||||
|
||||
<step_to_resolve>
|
||||
Remove the 'department' field and update the 'print_details' method
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Remove department",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "Delete",
|
||||
"path": "src/employee.rs",
|
||||
"symbol": "struct Employee department"
|
||||
},
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/employee.rs",
|
||||
"symbol": "impl Employee fn print_details",
|
||||
"description": "Don't print the 'department' field"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/game.rs
|
||||
struct Player {
|
||||
name: String,
|
||||
health: i32,
|
||||
pub score: u32,
|
||||
}
|
||||
|
||||
impl Player {
|
||||
pub fn new(name: String) -> Self {
|
||||
Player { name, health: 100, score: 0 }
|
||||
}
|
||||
}
|
||||
|
||||
struct Game {
|
||||
players: Vec<Player>,
|
||||
}
|
||||
|
||||
impl Game {
|
||||
fn new() -> Self {
|
||||
Game { players: Vec::new() }
|
||||
}
|
||||
}
|
||||
```
|
||||
</message>
|
||||
<message role="assistant">
|
||||
<step>Add a 'level' field to Player and update the 'new' method</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Add a 'level' field to Player and update the 'new' method
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Add level field to Player",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "InsertSiblingAfter",
|
||||
"path": "src/game.rs",
|
||||
"symbol": "struct Player pub score",
|
||||
"description": "Add level field to Player"
|
||||
},
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/game.rs",
|
||||
"symbol": "impl Player pub fn new",
|
||||
"description": "Initialize level in new method"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/config.rs
|
||||
use std::collections::HashMap;
|
||||
|
||||
struct Config {
|
||||
settings: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn new() -> Self {
|
||||
Config { settings: HashMap::new() }
|
||||
}
|
||||
}
|
||||
```
|
||||
</message>
|
||||
<message role="assistant">
|
||||
<step>Add a 'load_from_file' method to Config and import necessary modules</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Add a 'load_from_file' method to Config and import necessary modules
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Add load_from_file method",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "PrependChild",
|
||||
"path": "src/config.rs",
|
||||
"symbol": "#imports",
|
||||
"description": "Import std::fs and std::io modules"
|
||||
},
|
||||
{
|
||||
"kind": "AppendChild",
|
||||
"path": "src/config.rs",
|
||||
"symbol": "impl Config",
|
||||
"description": "Add load_from_file method"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
</example>
|
||||
|
||||
<example>
|
||||
<workflow_context>
|
||||
<message role="user">
|
||||
```rs src/database.rs
|
||||
pub(crate) struct Database {
|
||||
connection: Connection,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
fn new(url: &str) -> Result<Self, Error> {
|
||||
let connection = Connection::connect(url)?;
|
||||
Ok(Database { connection })
|
||||
}
|
||||
|
||||
async fn query(&self, sql: &str) -> Result<Vec<Row>, Error> {
|
||||
self.connection.query(sql, &[])
|
||||
}
|
||||
}
|
||||
```
|
||||
</message>
|
||||
<message role="assistant">
|
||||
<step>Add error handling to the 'query' method and create a custom error type</step>
|
||||
</message>
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
Add error handling to the 'query' method and create a custom error type
|
||||
</step_to_resolve>
|
||||
|
||||
<output>
|
||||
{
|
||||
"title": "Add error handling to query",
|
||||
"suggestions": [
|
||||
{
|
||||
"kind": "PrependChild",
|
||||
"path": "src/database.rs",
|
||||
"description": "Import necessary error handling modules"
|
||||
},
|
||||
{
|
||||
"kind": "InsertSiblingBefore",
|
||||
"path": "src/database.rs",
|
||||
"symbol": "pub(crate) struct Database",
|
||||
"description": "Define custom DatabaseError enum"
|
||||
},
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/database.rs",
|
||||
"symbol": "impl Database async fn query",
|
||||
"description": "Implement error handling in query method"
|
||||
}
|
||||
]
|
||||
}
|
||||
</output>
|
||||
</example>
|
||||
</examples>
|
||||
|
||||
Now generate the suggestions for the following step:
|
||||
|
||||
<workflow_context>
|
||||
{{{workflow_context}}}
|
||||
</workflow_context>
|
||||
|
||||
<step_to_resolve>
|
||||
{{{step_to_resolve}}}
|
||||
</step_to_resolve>
|
||||
@@ -118,8 +118,8 @@
|
||||
// "bar"
|
||||
// 2. A block that surrounds the following character
|
||||
// "block"
|
||||
// 3. An underline / underscore that runs along the following character
|
||||
// "underline"
|
||||
// 3. An underline that runs along the following character
|
||||
// "underscore"
|
||||
// 4. A box drawn around the following character
|
||||
// "hollow"
|
||||
//
|
||||
@@ -494,14 +494,7 @@
|
||||
// Position of the close button on the editor tabs.
|
||||
"close_position": "right",
|
||||
// Whether to show the file icon for a tab.
|
||||
"file_icons": false,
|
||||
// What to do after closing the current tab.
|
||||
//
|
||||
// 1. Activate the tab that was open previously (default)
|
||||
// "History"
|
||||
// 2. Activate the neighbour tab (prefers the right one, if present)
|
||||
// "Neighbour"
|
||||
"activate_on_close": "history"
|
||||
"file_icons": false
|
||||
},
|
||||
// Settings related to preview tabs.
|
||||
"preview_tabs": {
|
||||
@@ -691,8 +684,8 @@
|
||||
// "block"
|
||||
// 2. A vertical bar
|
||||
// "bar"
|
||||
// 3. An underline / underscore that runs along the following character
|
||||
// "underline"
|
||||
// 3. An underline that runs along the following character
|
||||
// "underscore"
|
||||
// 4. A box drawn around the following character
|
||||
// "hollow"
|
||||
//
|
||||
@@ -712,10 +705,10 @@
|
||||
// May take 2 values:
|
||||
// 1. Rely on default platform handling of option key, on macOS
|
||||
// this means generating certain unicode characters
|
||||
// "option_as_meta": false,
|
||||
// "option_to_meta": false,
|
||||
// 2. Make the option keys behave as a 'meta' key, e.g. for emacs
|
||||
// "option_as_meta": true,
|
||||
"option_as_meta": false,
|
||||
// "option_to_meta": true,
|
||||
"option_as_meta": true,
|
||||
// Whether or not selecting text in the terminal will automatically
|
||||
// copy to the system clipboard.
|
||||
"copy_on_select": false,
|
||||
@@ -851,10 +844,6 @@
|
||||
"Dart": {
|
||||
"tab_size": 2
|
||||
},
|
||||
"Diff": {
|
||||
"remove_trailing_whitespace_on_save": false,
|
||||
"ensure_final_newline_on_save": false
|
||||
},
|
||||
"Elixir": {
|
||||
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
|
||||
},
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
// Server-specific settings
|
||||
//
|
||||
// For a full list of overridable settings, and general information on settings,
|
||||
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
|
||||
{
|
||||
"lsp": {}
|
||||
}
|
||||
@@ -101,7 +101,6 @@ impl ActivityIndicator {
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
buffer.set_capability(language::Capability::ReadOnly, cx);
|
||||
})?;
|
||||
workspace.update(&mut cx, |workspace, cx| {
|
||||
workspace.add_item_to_active_pane(
|
||||
|
||||
@@ -26,3 +26,6 @@ serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -97,7 +97,6 @@ language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
languages = { workspace = true, features = ["test-support"] }
|
||||
log.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
|
||||
@@ -6,7 +6,6 @@ mod context;
|
||||
pub mod context_store;
|
||||
mod inline_assistant;
|
||||
mod model_selector;
|
||||
mod patch;
|
||||
mod prompt_library;
|
||||
mod prompts;
|
||||
mod slash_command;
|
||||
@@ -15,6 +14,7 @@ pub mod slash_command_settings;
|
||||
mod streaming_diff;
|
||||
mod terminal_inline_assistant;
|
||||
mod tools;
|
||||
mod workflow;
|
||||
|
||||
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
|
||||
use assistant_settings::AssistantSettings;
|
||||
@@ -35,13 +35,11 @@ use language_model::{
|
||||
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
|
||||
};
|
||||
pub(crate) use model_selector::*;
|
||||
pub use patch::*;
|
||||
pub use prompts::PromptBuilder;
|
||||
use prompts::PromptLoadingParams;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticDb};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{update_settings_file, Settings, SettingsStore};
|
||||
use slash_command::workflow_command::WorkflowSlashCommand;
|
||||
use slash_command::{
|
||||
auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
|
||||
diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
|
||||
@@ -52,6 +50,7 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
pub(crate) use streaming_diff::*;
|
||||
use util::ResultExt;
|
||||
pub use workflow::*;
|
||||
|
||||
use crate::slash_command_settings::SlashCommandSettings;
|
||||
|
||||
@@ -394,25 +393,12 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
|
||||
slash_command_registry.register_command(now_command::NowSlashCommand, false);
|
||||
slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
|
||||
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
|
||||
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
|
||||
|
||||
if let Some(prompt_builder) = prompt_builder {
|
||||
cx.observe_global::<SettingsStore>({
|
||||
let slash_command_registry = slash_command_registry.clone();
|
||||
let prompt_builder = prompt_builder.clone();
|
||||
move |cx| {
|
||||
if AssistantSettings::get_global(cx).are_live_diffs_enabled(cx) {
|
||||
slash_command_registry.register_command(
|
||||
workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
|
||||
true,
|
||||
);
|
||||
} else {
|
||||
slash_command_registry.unregister_command_by_name(WorkflowSlashCommand::NAME);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
slash_command_registry.register_command(
|
||||
workflow_command::WorkflowSlashCommand::new(prompt_builder.clone()),
|
||||
true,
|
||||
);
|
||||
cx.observe_flag::<project_command::ProjectSlashCommandFeatureFlag, _>({
|
||||
let slash_command_registry = slash_command_registry.clone();
|
||||
move |is_enabled, _cx| {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,6 @@ use std::sync::Arc;
|
||||
|
||||
use ::open_ai::Model as OpenAiModel;
|
||||
use anthropic::Model as AnthropicModel;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use fs::Fs;
|
||||
use gpui::{AppContext, Pixels};
|
||||
use language_model::provider::open_ai;
|
||||
@@ -62,13 +61,6 @@ pub struct AssistantSettings {
|
||||
pub default_model: LanguageModelSelection,
|
||||
pub inline_alternatives: Vec<LanguageModelSelection>,
|
||||
pub using_outdated_settings_version: bool,
|
||||
pub enable_experimental_live_diffs: bool,
|
||||
}
|
||||
|
||||
impl AssistantSettings {
|
||||
pub fn are_live_diffs_enabled(&self, cx: &AppContext) -> bool {
|
||||
cx.is_staff() || self.enable_experimental_live_diffs
|
||||
}
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
@@ -246,7 +238,6 @@ impl AssistantSettingsContent {
|
||||
}
|
||||
}),
|
||||
inline_alternatives: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
|
||||
},
|
||||
@@ -266,7 +257,6 @@ impl AssistantSettingsContent {
|
||||
.to_string(),
|
||||
}),
|
||||
inline_alternatives: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -383,7 +373,6 @@ impl Default for VersionedAssistantSettingsContent {
|
||||
default_height: None,
|
||||
default_model: None,
|
||||
inline_alternatives: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -414,10 +403,6 @@ pub struct AssistantSettingsContentV2 {
|
||||
default_model: Option<LanguageModelSelection>,
|
||||
/// Additional models with which to generate alternatives when performing inline assists.
|
||||
inline_alternatives: Option<Vec<LanguageModelSelection>>,
|
||||
/// Enable experimental live diffs in the assistant panel.
|
||||
///
|
||||
/// Default: false
|
||||
enable_experimental_live_diffs: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
@@ -540,10 +525,7 @@ impl Settings for AssistantSettings {
|
||||
);
|
||||
merge(&mut settings.default_model, value.default_model);
|
||||
merge(&mut settings.inline_alternatives, value.inline_alternatives);
|
||||
merge(
|
||||
&mut settings.enable_experimental_live_diffs,
|
||||
value.enable_experimental_live_diffs,
|
||||
);
|
||||
// merge(&mut settings.infer_context, value.infer_context); TODO re-enable this once we ship context inference
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
@@ -602,7 +584,6 @@ mod tests {
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
}),
|
||||
)
|
||||
},
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
mod context_tests;
|
||||
|
||||
use crate::{
|
||||
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantEdit, AssistantPatch,
|
||||
AssistantPatchStatus, MessageId, MessageStatus,
|
||||
prompts::PromptBuilder, slash_command::SlashCommandLine, MessageId, MessageStatus,
|
||||
WorkflowStep, WorkflowStepEdit, WorkflowStepResolution, WorkflowSuggestionGroup,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
@@ -15,10 +15,13 @@ use clock::ReplicaId;
|
||||
use collections::{HashMap, HashSet};
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::{future::Shared, FutureExt, StreamExt};
|
||||
use futures::{
|
||||
future::{self, Shared},
|
||||
FutureExt, StreamExt,
|
||||
};
|
||||
use gpui::{
|
||||
AppContext, Context as _, EventEmitter, Model, ModelContext, RenderImage, SharedString,
|
||||
Subscription, Task,
|
||||
AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, RenderImage,
|
||||
SharedString, Subscription, Task,
|
||||
};
|
||||
|
||||
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
|
||||
@@ -35,7 +38,7 @@ use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
cmp::{max, Ordering},
|
||||
cmp::{self, max, Ordering},
|
||||
fmt::Debug,
|
||||
iter, mem,
|
||||
ops::Range,
|
||||
@@ -297,7 +300,7 @@ pub enum ContextEvent {
|
||||
MessagesEdited,
|
||||
SummaryChanged,
|
||||
StreamedCompletion,
|
||||
PatchesUpdated {
|
||||
WorkflowStepsUpdated {
|
||||
removed: Vec<Range<language::Anchor>>,
|
||||
updated: Vec<Range<language::Anchor>>,
|
||||
},
|
||||
@@ -451,14 +454,13 @@ pub struct XmlTag {
|
||||
#[derive(Copy, Clone, Debug, strum::EnumString, PartialEq, Eq, strum::AsRefStr)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum XmlTagKind {
|
||||
Patch,
|
||||
Title,
|
||||
Step,
|
||||
Edit,
|
||||
Path,
|
||||
Description,
|
||||
OldText,
|
||||
NewText,
|
||||
Search,
|
||||
Within,
|
||||
Operation,
|
||||
Description,
|
||||
}
|
||||
|
||||
pub struct Context {
|
||||
@@ -488,7 +490,7 @@ pub struct Context {
|
||||
_subscriptions: Vec<Subscription>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
patches: Vec<AssistantPatch>,
|
||||
workflow_steps: Vec<WorkflowStep>,
|
||||
xml_tags: Vec<XmlTag>,
|
||||
project: Option<Model<Project>>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
@@ -504,7 +506,7 @@ impl ContextAnnotation for PendingSlashCommand {
|
||||
}
|
||||
}
|
||||
|
||||
impl ContextAnnotation for AssistantPatch {
|
||||
impl ContextAnnotation for WorkflowStep {
|
||||
fn range(&self) -> &Range<language::Anchor> {
|
||||
&self.range
|
||||
}
|
||||
@@ -589,7 +591,7 @@ impl Context {
|
||||
telemetry,
|
||||
project,
|
||||
language_registry,
|
||||
patches: Vec::new(),
|
||||
workflow_steps: Vec::new(),
|
||||
xml_tags: Vec::new(),
|
||||
prompt_builder,
|
||||
};
|
||||
@@ -927,49 +929,48 @@ impl Context {
|
||||
self.summary.as_ref()
|
||||
}
|
||||
|
||||
pub(crate) fn patch_containing(
|
||||
pub(crate) fn workflow_step_containing(
|
||||
&self,
|
||||
position: Point,
|
||||
offset: usize,
|
||||
cx: &AppContext,
|
||||
) -> Option<&AssistantPatch> {
|
||||
) -> Option<&WorkflowStep> {
|
||||
let buffer = self.buffer.read(cx);
|
||||
let index = self.patches.binary_search_by(|patch| {
|
||||
let patch_range = patch.range.to_point(&buffer);
|
||||
if position < patch_range.start {
|
||||
Ordering::Greater
|
||||
} else if position > patch_range.end {
|
||||
Ordering::Less
|
||||
} else {
|
||||
Ordering::Equal
|
||||
}
|
||||
});
|
||||
if let Ok(ix) = index {
|
||||
Some(&self.patches[ix])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let index = self
|
||||
.workflow_steps
|
||||
.binary_search_by(|step| {
|
||||
let step_range = step.range.to_offset(&buffer);
|
||||
if offset < step_range.start {
|
||||
Ordering::Greater
|
||||
} else if offset > step_range.end {
|
||||
Ordering::Less
|
||||
} else {
|
||||
Ordering::Equal
|
||||
}
|
||||
})
|
||||
.ok()?;
|
||||
Some(&self.workflow_steps[index])
|
||||
}
|
||||
|
||||
pub fn patch_ranges(&self) -> impl Iterator<Item = Range<language::Anchor>> + '_ {
|
||||
self.patches.iter().map(|patch| patch.range.clone())
|
||||
pub fn workflow_step_ranges(&self) -> impl Iterator<Item = Range<language::Anchor>> + '_ {
|
||||
self.workflow_steps.iter().map(|step| step.range.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn patch_for_range(
|
||||
pub(crate) fn workflow_step_for_range(
|
||||
&self,
|
||||
range: &Range<language::Anchor>,
|
||||
cx: &AppContext,
|
||||
) -> Option<&AssistantPatch> {
|
||||
) -> Option<&WorkflowStep> {
|
||||
let buffer = self.buffer.read(cx);
|
||||
let index = self.patch_index_for_range(range, buffer).ok()?;
|
||||
Some(&self.patches[index])
|
||||
let index = self.workflow_step_index_for_range(range, buffer).ok()?;
|
||||
Some(&self.workflow_steps[index])
|
||||
}
|
||||
|
||||
fn patch_index_for_range(
|
||||
fn workflow_step_index_for_range(
|
||||
&self,
|
||||
tagged_range: &Range<text::Anchor>,
|
||||
buffer: &text::BufferSnapshot,
|
||||
) -> Result<usize, usize> {
|
||||
self.patches
|
||||
self.workflow_steps
|
||||
.binary_search_by(|probe| probe.range.cmp(&tagged_range, buffer))
|
||||
}
|
||||
|
||||
@@ -1017,6 +1018,8 @@ impl Context {
|
||||
language::BufferEvent::Edited => {
|
||||
self.count_remaining_tokens(cx);
|
||||
self.reparse(cx);
|
||||
// Use `inclusive = true` to invalidate a step when an edit occurs
|
||||
// at the start/end of a parsed step.
|
||||
cx.emit(ContextEvent::MessagesEdited);
|
||||
}
|
||||
_ => {}
|
||||
@@ -1245,8 +1248,8 @@ impl Context {
|
||||
|
||||
let mut removed_slash_command_ranges = Vec::new();
|
||||
let mut updated_slash_commands = Vec::new();
|
||||
let mut removed_patches = Vec::new();
|
||||
let mut updated_patches = Vec::new();
|
||||
let mut removed_steps = Vec::new();
|
||||
let mut updated_steps = Vec::new();
|
||||
while let Some(mut row_range) = row_ranges.next() {
|
||||
while let Some(next_row_range) = row_ranges.peek() {
|
||||
if row_range.end >= next_row_range.start {
|
||||
@@ -1270,11 +1273,11 @@ impl Context {
|
||||
&mut removed_slash_command_ranges,
|
||||
cx,
|
||||
);
|
||||
self.reparse_patches_in_range(
|
||||
self.reparse_workflow_steps_in_range(
|
||||
start..end,
|
||||
&buffer,
|
||||
&mut updated_patches,
|
||||
&mut removed_patches,
|
||||
&mut updated_steps,
|
||||
&mut removed_steps,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1286,10 +1289,10 @@ impl Context {
|
||||
});
|
||||
}
|
||||
|
||||
if !updated_patches.is_empty() || !removed_patches.is_empty() {
|
||||
cx.emit(ContextEvent::PatchesUpdated {
|
||||
removed: removed_patches,
|
||||
updated: updated_patches,
|
||||
if !updated_steps.is_empty() || !removed_steps.is_empty() {
|
||||
cx.emit(ContextEvent::WorkflowStepsUpdated {
|
||||
removed: removed_steps,
|
||||
updated: updated_steps,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1351,7 +1354,7 @@ impl Context {
|
||||
removed.extend(removed_commands.map(|command| command.source_range));
|
||||
}
|
||||
|
||||
fn reparse_patches_in_range(
|
||||
fn reparse_workflow_steps_in_range(
|
||||
&mut self,
|
||||
range: Range<text::Anchor>,
|
||||
buffer: &BufferSnapshot,
|
||||
@@ -1366,32 +1369,41 @@ impl Context {
|
||||
self.xml_tags
|
||||
.splice(intersecting_tags_range.clone(), new_tags);
|
||||
|
||||
// Find which patches intersect the changed range.
|
||||
let intersecting_patches_range =
|
||||
self.indices_intersecting_buffer_range(&self.patches, range.clone(), cx);
|
||||
// Find which steps intersect the changed range.
|
||||
let intersecting_steps_range =
|
||||
self.indices_intersecting_buffer_range(&self.workflow_steps, range.clone(), cx);
|
||||
|
||||
// Reparse all tags after the last unchanged patch before the change.
|
||||
// Reparse all tags after the last unchanged step before the change.
|
||||
let mut tags_start_ix = 0;
|
||||
if let Some(preceding_unchanged_patch) =
|
||||
self.patches[..intersecting_patches_range.start].last()
|
||||
if let Some(preceding_unchanged_step) =
|
||||
self.workflow_steps[..intersecting_steps_range.start].last()
|
||||
{
|
||||
tags_start_ix = match self.xml_tags.binary_search_by(|tag| {
|
||||
tag.range
|
||||
.start
|
||||
.cmp(&preceding_unchanged_patch.range.end, buffer)
|
||||
.cmp(&preceding_unchanged_step.range.end, buffer)
|
||||
.then(Ordering::Less)
|
||||
}) {
|
||||
Ok(ix) | Err(ix) => ix,
|
||||
};
|
||||
}
|
||||
|
||||
// Rebuild the patches in the range.
|
||||
let new_patches = self.parse_patches(tags_start_ix, range.end, buffer, cx);
|
||||
updated.extend(new_patches.iter().map(|patch| patch.range.clone()));
|
||||
let removed_patches = self.patches.splice(intersecting_patches_range, new_patches);
|
||||
// Rebuild the edit suggestions in the range.
|
||||
let mut new_steps = self.parse_steps(tags_start_ix, range.end, buffer);
|
||||
|
||||
if let Some(project) = self.project() {
|
||||
for step in &mut new_steps {
|
||||
Self::resolve_workflow_step_internal(step, &project, cx);
|
||||
}
|
||||
}
|
||||
|
||||
updated.extend(new_steps.iter().map(|step| step.range.clone()));
|
||||
let removed_steps = self
|
||||
.workflow_steps
|
||||
.splice(intersecting_steps_range, new_steps);
|
||||
removed.extend(
|
||||
removed_patches
|
||||
.map(|patch| patch.range)
|
||||
removed_steps
|
||||
.map(|step| step.range)
|
||||
.filter(|range| !updated.contains(&range)),
|
||||
);
|
||||
}
|
||||
@@ -1452,95 +1464,60 @@ impl Context {
|
||||
tags
|
||||
}
|
||||
|
||||
fn parse_patches(
|
||||
fn parse_steps(
|
||||
&mut self,
|
||||
tags_start_ix: usize,
|
||||
buffer_end: text::Anchor,
|
||||
buffer: &BufferSnapshot,
|
||||
cx: &AppContext,
|
||||
) -> Vec<AssistantPatch> {
|
||||
let mut new_patches = Vec::new();
|
||||
let mut pending_patch = None;
|
||||
let mut patch_tag_depth = 0;
|
||||
) -> Vec<WorkflowStep> {
|
||||
let mut new_steps = Vec::new();
|
||||
let mut pending_step = None;
|
||||
let mut edit_step_depth = 0;
|
||||
let mut tags = self.xml_tags[tags_start_ix..].iter().peekable();
|
||||
'tags: while let Some(tag) = tags.next() {
|
||||
if tag.range.start.cmp(&buffer_end, buffer).is_gt() && patch_tag_depth == 0 {
|
||||
if tag.range.start.cmp(&buffer_end, buffer).is_gt() && edit_step_depth == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if tag.kind == XmlTagKind::Patch && tag.is_open_tag {
|
||||
patch_tag_depth += 1;
|
||||
let patch_start = tag.range.start;
|
||||
let mut edits = Vec::<Result<AssistantEdit>>::new();
|
||||
let mut patch = AssistantPatch {
|
||||
range: patch_start..patch_start,
|
||||
title: String::new().into(),
|
||||
if tag.kind == XmlTagKind::Step && tag.is_open_tag {
|
||||
edit_step_depth += 1;
|
||||
let edit_start = tag.range.start;
|
||||
let mut edits = Vec::new();
|
||||
let mut step = WorkflowStep {
|
||||
range: edit_start..edit_start,
|
||||
leading_tags_end: tag.range.end,
|
||||
trailing_tag_start: None,
|
||||
edits: Default::default(),
|
||||
status: crate::AssistantPatchStatus::Pending,
|
||||
resolution: None,
|
||||
resolution_task: None,
|
||||
};
|
||||
|
||||
while let Some(tag) = tags.next() {
|
||||
if tag.kind == XmlTagKind::Patch && !tag.is_open_tag {
|
||||
patch_tag_depth -= 1;
|
||||
if patch_tag_depth == 0 {
|
||||
patch.range.end = tag.range.end;
|
||||
step.trailing_tag_start.get_or_insert(tag.range.start);
|
||||
|
||||
// Include the line immediately after this <patch> tag if it's empty.
|
||||
let patch_end_offset = patch.range.end.to_offset(buffer);
|
||||
let mut patch_end_chars = buffer.chars_at(patch_end_offset);
|
||||
if patch_end_chars.next() == Some('\n')
|
||||
&& patch_end_chars.next().map_or(true, |ch| ch == '\n')
|
||||
{
|
||||
let messages = self.messages_for_offsets(
|
||||
[patch_end_offset, patch_end_offset + 1],
|
||||
cx,
|
||||
);
|
||||
if messages.len() == 1 {
|
||||
patch.range.end = buffer.anchor_before(patch_end_offset + 1);
|
||||
}
|
||||
}
|
||||
|
||||
edits.sort_unstable_by(|a, b| {
|
||||
if let (Ok(a), Ok(b)) = (a, b) {
|
||||
a.path.cmp(&b.path)
|
||||
} else {
|
||||
Ordering::Equal
|
||||
}
|
||||
});
|
||||
patch.edits = edits.into();
|
||||
patch.status = AssistantPatchStatus::Ready;
|
||||
new_patches.push(patch);
|
||||
if tag.kind == XmlTagKind::Step && !tag.is_open_tag {
|
||||
// step.trailing_tag_start = Some(tag.range.start);
|
||||
edit_step_depth -= 1;
|
||||
if edit_step_depth == 0 {
|
||||
step.range.end = tag.range.end;
|
||||
step.edits = edits.into();
|
||||
new_steps.push(step);
|
||||
continue 'tags;
|
||||
}
|
||||
}
|
||||
|
||||
if tag.kind == XmlTagKind::Title && tag.is_open_tag {
|
||||
let content_start = tag.range.end;
|
||||
while let Some(tag) = tags.next() {
|
||||
if tag.kind == XmlTagKind::Title && !tag.is_open_tag {
|
||||
let content_end = tag.range.start;
|
||||
patch.title =
|
||||
trimmed_text_in_range(buffer, content_start..content_end)
|
||||
.into();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tag.kind == XmlTagKind::Edit && tag.is_open_tag {
|
||||
let mut path = None;
|
||||
let mut old_text = None;
|
||||
let mut new_text = None;
|
||||
let mut search = None;
|
||||
let mut operation = None;
|
||||
let mut description = None;
|
||||
|
||||
while let Some(tag) = tags.next() {
|
||||
if tag.kind == XmlTagKind::Edit && !tag.is_open_tag {
|
||||
edits.push(AssistantEdit::new(
|
||||
edits.push(WorkflowStepEdit::new(
|
||||
path,
|
||||
operation,
|
||||
old_text,
|
||||
new_text,
|
||||
search,
|
||||
description,
|
||||
));
|
||||
break;
|
||||
@@ -1549,8 +1526,7 @@ impl Context {
|
||||
if tag.is_open_tag
|
||||
&& [
|
||||
XmlTagKind::Path,
|
||||
XmlTagKind::OldText,
|
||||
XmlTagKind::NewText,
|
||||
XmlTagKind::Search,
|
||||
XmlTagKind::Operation,
|
||||
XmlTagKind::Description,
|
||||
]
|
||||
@@ -1562,18 +1538,15 @@ impl Context {
|
||||
if tag.kind == kind && !tag.is_open_tag {
|
||||
let tag = tags.next().unwrap();
|
||||
let content_end = tag.range.start;
|
||||
let content = trimmed_text_in_range(
|
||||
buffer,
|
||||
content_start..content_end,
|
||||
);
|
||||
let mut content = buffer
|
||||
.text_for_range(content_start..content_end)
|
||||
.collect::<String>();
|
||||
content.truncate(content.trim_end().len());
|
||||
match kind {
|
||||
XmlTagKind::Path => path = Some(content),
|
||||
XmlTagKind::Operation => operation = Some(content),
|
||||
XmlTagKind::OldText => {
|
||||
old_text = Some(content).filter(|s| !s.is_empty())
|
||||
}
|
||||
XmlTagKind::NewText => {
|
||||
new_text = Some(content).filter(|s| !s.is_empty())
|
||||
XmlTagKind::Search => {
|
||||
search = Some(content).filter(|s| !s.is_empty())
|
||||
}
|
||||
XmlTagKind::Description => {
|
||||
description =
|
||||
@@ -1588,28 +1561,162 @@ impl Context {
|
||||
}
|
||||
}
|
||||
|
||||
patch.edits = edits.into();
|
||||
pending_patch = Some(patch);
|
||||
pending_step = Some(step);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(mut pending_patch) = pending_patch {
|
||||
let patch_start = pending_patch.range.start.to_offset(buffer);
|
||||
if let Some(message) = self.message_for_offset(patch_start, cx) {
|
||||
if message.anchor_range.end == text::Anchor::MAX {
|
||||
pending_patch.range.end = text::Anchor::MAX;
|
||||
if let Some(mut pending_step) = pending_step {
|
||||
pending_step.range.end = text::Anchor::MAX;
|
||||
new_steps.push(pending_step);
|
||||
}
|
||||
|
||||
new_steps
|
||||
}
|
||||
|
||||
pub fn resolve_workflow_step(
|
||||
&mut self,
|
||||
tagged_range: Range<text::Anchor>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Option<()> {
|
||||
let index = self
|
||||
.workflow_step_index_for_range(&tagged_range, self.buffer.read(cx))
|
||||
.ok()?;
|
||||
let step = &mut self.workflow_steps[index];
|
||||
let project = self.project.as_ref()?;
|
||||
step.resolution.take();
|
||||
Self::resolve_workflow_step_internal(step, project, cx);
|
||||
None
|
||||
}
|
||||
|
||||
fn resolve_workflow_step_internal(
|
||||
step: &mut WorkflowStep,
|
||||
project: &Model<Project>,
|
||||
cx: &mut ModelContext<'_, Context>,
|
||||
) {
|
||||
step.resolution_task = Some(cx.spawn({
|
||||
let range = step.range.clone();
|
||||
let edits = step.edits.clone();
|
||||
let project = project.clone();
|
||||
|this, mut cx| async move {
|
||||
let suggestion_groups =
|
||||
Self::compute_step_resolution(project, edits, &mut cx).await;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let buffer = this.buffer.read(cx).text_snapshot();
|
||||
let ix = this.workflow_step_index_for_range(&range, &buffer).ok();
|
||||
if let Some(ix) = ix {
|
||||
let step = &mut this.workflow_steps[ix];
|
||||
|
||||
let resolution = suggestion_groups.map(|suggestion_groups| {
|
||||
let mut title = String::new();
|
||||
for mut chunk in buffer.text_for_range(
|
||||
step.leading_tags_end
|
||||
..step.trailing_tag_start.unwrap_or(step.range.end),
|
||||
) {
|
||||
if title.is_empty() {
|
||||
chunk = chunk.trim_start();
|
||||
}
|
||||
if let Some((prefix, _)) = chunk.split_once('\n') {
|
||||
title.push_str(prefix);
|
||||
break;
|
||||
} else {
|
||||
title.push_str(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
WorkflowStepResolution {
|
||||
title,
|
||||
suggestion_groups,
|
||||
}
|
||||
});
|
||||
|
||||
step.resolution = Some(Arc::new(resolution));
|
||||
cx.emit(ContextEvent::WorkflowStepsUpdated {
|
||||
removed: vec![],
|
||||
updated: vec![range],
|
||||
})
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
async fn compute_step_resolution(
|
||||
project: Model<Project>,
|
||||
edits: Arc<[Result<WorkflowStepEdit>]>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> Result<HashMap<Model<Buffer>, Vec<WorkflowSuggestionGroup>>> {
|
||||
let mut suggestion_tasks = Vec::new();
|
||||
for edit in edits.iter() {
|
||||
let edit = edit.as_ref().map_err(|e| anyhow!("{e}"))?;
|
||||
suggestion_tasks.push(edit.resolve(project.clone(), cx.clone()));
|
||||
}
|
||||
|
||||
// Expand the context ranges of each suggestion and group suggestions with overlapping context ranges.
|
||||
let suggestions = future::try_join_all(suggestion_tasks).await?;
|
||||
|
||||
let mut suggestions_by_buffer = HashMap::default();
|
||||
for (buffer, suggestion) in suggestions {
|
||||
suggestions_by_buffer
|
||||
.entry(buffer)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(suggestion);
|
||||
}
|
||||
|
||||
let mut suggestion_groups_by_buffer = HashMap::default();
|
||||
for (buffer, mut suggestions) in suggestions_by_buffer {
|
||||
let mut suggestion_groups = Vec::<WorkflowSuggestionGroup>::new();
|
||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
|
||||
// Sort suggestions by their range so that earlier, larger ranges come first
|
||||
suggestions.sort_by(|a, b| a.range().cmp(&b.range(), &snapshot));
|
||||
|
||||
// Merge overlapping suggestions
|
||||
suggestions.dedup_by(|a, b| b.try_merge(a, &snapshot));
|
||||
|
||||
// Create context ranges for each suggestion
|
||||
for suggestion in suggestions {
|
||||
let context_range = {
|
||||
let suggestion_point_range = suggestion.range().to_point(&snapshot);
|
||||
let start_row = suggestion_point_range.start.row.saturating_sub(5);
|
||||
let end_row =
|
||||
cmp::min(suggestion_point_range.end.row + 5, snapshot.max_point().row);
|
||||
let start = snapshot.anchor_before(Point::new(start_row, 0));
|
||||
let end =
|
||||
snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
|
||||
start..end
|
||||
};
|
||||
|
||||
if let Some(last_group) = suggestion_groups.last_mut() {
|
||||
if last_group
|
||||
.context_range
|
||||
.end
|
||||
.cmp(&context_range.start, &snapshot)
|
||||
.is_ge()
|
||||
{
|
||||
// Merge with the previous group if context ranges overlap
|
||||
last_group.context_range.end = context_range.end;
|
||||
last_group.suggestions.push(suggestion);
|
||||
} else {
|
||||
// Create a new group
|
||||
suggestion_groups.push(WorkflowSuggestionGroup {
|
||||
context_range,
|
||||
suggestions: vec![suggestion],
|
||||
});
|
||||
}
|
||||
} else {
|
||||
let message_end = buffer.anchor_after(message.offset_range.end - 1);
|
||||
pending_patch.range.end = message_end;
|
||||
// Create the first group
|
||||
suggestion_groups.push(WorkflowSuggestionGroup {
|
||||
context_range,
|
||||
suggestions: vec![suggestion],
|
||||
});
|
||||
}
|
||||
} else {
|
||||
pending_patch.range.end = text::Anchor::MAX;
|
||||
}
|
||||
|
||||
new_patches.push(pending_patch);
|
||||
suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
|
||||
}
|
||||
|
||||
new_patches
|
||||
Ok(suggestion_groups_by_buffer)
|
||||
}
|
||||
|
||||
pub fn pending_command_for_position(
|
||||
@@ -2053,7 +2160,7 @@ impl Context {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
language_name,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2208,11 +2315,11 @@ impl Context {
|
||||
let mut updated = Vec::new();
|
||||
let mut removed = Vec::new();
|
||||
for range in ranges {
|
||||
self.reparse_patches_in_range(range, &buffer, &mut updated, &mut removed, cx);
|
||||
self.reparse_workflow_steps_in_range(range, &buffer, &mut updated, &mut removed, cx);
|
||||
}
|
||||
|
||||
if !updated.is_empty() || !removed.is_empty() {
|
||||
cx.emit(ContextEvent::PatchesUpdated { removed, updated })
|
||||
cx.emit(ContextEvent::WorkflowStepsUpdated { removed, updated })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2718,24 +2825,6 @@ impl Context {
|
||||
}
|
||||
}
|
||||
|
||||
fn trimmed_text_in_range(buffer: &BufferSnapshot, range: Range<text::Anchor>) -> String {
|
||||
let mut is_start = true;
|
||||
let mut content = buffer
|
||||
.text_for_range(range)
|
||||
.map(|mut chunk| {
|
||||
if is_start {
|
||||
chunk = chunk.trim_start_matches('\n');
|
||||
if !chunk.is_empty() {
|
||||
is_start = false;
|
||||
}
|
||||
}
|
||||
chunk
|
||||
})
|
||||
.collect::<String>();
|
||||
content.truncate(content.trim_end().len());
|
||||
content
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ContextVersion {
|
||||
context: clock::Global,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use super::{AssistantEdit, MessageCacheMetadata};
|
||||
use super::{MessageCacheMetadata, WorkflowStepEdit};
|
||||
use crate::{
|
||||
assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
|
||||
Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
|
||||
assistant_panel, prompt_library, slash_command::file_command, CacheStatus, Context,
|
||||
ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
|
||||
WorkflowStepEditKind,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{
|
||||
@@ -14,7 +15,6 @@ use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
|
||||
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
|
||||
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
use serde_json::json;
|
||||
@@ -478,15 +478,7 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
|
||||
#[gpui::test]
|
||||
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
cx.update(prompt_library::init);
|
||||
let mut settings_store = cx.update(SettingsStore::test);
|
||||
cx.update(|cx| {
|
||||
settings_store
|
||||
.set_user_settings(
|
||||
r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
cx.set_global(settings_store);
|
||||
cx.update(language::init);
|
||||
cx.update(Project::init_settings);
|
||||
@@ -528,7 +520,7 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
»",
|
||||
cx,
|
||||
);
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
@@ -547,17 +539,17 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
one
|
||||
two
|
||||
«
|
||||
<patch»",
|
||||
<step»",
|
||||
cx,
|
||||
);
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
<patch",
|
||||
<step",
|
||||
&[],
|
||||
cx,
|
||||
);
|
||||
@@ -571,24 +563,36 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
one
|
||||
two
|
||||
|
||||
<patch«>
|
||||
<step«>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>»",
|
||||
cx,
|
||||
);
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
«<patch>
|
||||
«<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>»",
|
||||
&[&[]],
|
||||
cx,
|
||||
);
|
||||
|
||||
// The full patch is added
|
||||
// The full suggestion is added
|
||||
edit(
|
||||
&context,
|
||||
"
|
||||
@@ -596,46 +600,51 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
one
|
||||
two
|
||||
|
||||
<patch>
|
||||
<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>«
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>fn one</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>fn one</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
</step>
|
||||
|
||||
also,»",
|
||||
cx,
|
||||
);
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
«<patch>
|
||||
«<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>fn one</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>fn one</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
»
|
||||
</step>»
|
||||
|
||||
also,",
|
||||
&[&[AssistantEdit {
|
||||
&[&[WorkflowStepEdit {
|
||||
path: "src/lib.rs".into(),
|
||||
kind: AssistantEditKind::InsertAfter {
|
||||
old_text: "fn one".into(),
|
||||
new_text: "fn two() {}".into(),
|
||||
kind: WorkflowStepEditKind::InsertAfter {
|
||||
search: "fn one".into(),
|
||||
description: "add a `two` function".into(),
|
||||
},
|
||||
}]],
|
||||
@@ -650,46 +659,51 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
one
|
||||
two
|
||||
|
||||
<patch>
|
||||
<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>«fn zero»</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>«fn zero»</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
</step>
|
||||
|
||||
also,",
|
||||
cx,
|
||||
);
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
«<patch>
|
||||
«<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>fn zero</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>fn zero</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
»
|
||||
</step>»
|
||||
|
||||
also,",
|
||||
&[&[AssistantEdit {
|
||||
&[&[WorkflowStepEdit {
|
||||
path: "src/lib.rs".into(),
|
||||
kind: AssistantEditKind::InsertAfter {
|
||||
old_text: "fn zero".into(),
|
||||
new_text: "fn two() {}".into(),
|
||||
kind: WorkflowStepEditKind::InsertAfter {
|
||||
search: "fn zero".into(),
|
||||
description: "add a `two` function".into(),
|
||||
},
|
||||
}]],
|
||||
@@ -701,24 +715,27 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
|
||||
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
|
||||
});
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
<patch>
|
||||
<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>fn zero</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>fn zero</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
</step>
|
||||
|
||||
also,",
|
||||
&[],
|
||||
@@ -729,31 +746,33 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
context.update(cx, |context, cx| {
|
||||
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
|
||||
});
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
«<patch>
|
||||
«<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>fn zero</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>fn zero</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
»
|
||||
</step>»
|
||||
|
||||
also,",
|
||||
&[&[AssistantEdit {
|
||||
&[&[WorkflowStepEdit {
|
||||
path: "src/lib.rs".into(),
|
||||
kind: AssistantEditKind::InsertAfter {
|
||||
old_text: "fn zero".into(),
|
||||
new_text: "fn two() {}".into(),
|
||||
kind: WorkflowStepEditKind::InsertAfter {
|
||||
search: "fn zero".into(),
|
||||
description: "add a `two` function".into(),
|
||||
},
|
||||
}]],
|
||||
@@ -773,31 +792,33 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
cx,
|
||||
)
|
||||
});
|
||||
expect_patches(
|
||||
expect_steps(
|
||||
&deserialized_context,
|
||||
"
|
||||
|
||||
one
|
||||
two
|
||||
|
||||
«<patch>
|
||||
«<step>
|
||||
Add a second function
|
||||
|
||||
```rust
|
||||
fn two() {}
|
||||
```
|
||||
|
||||
<edit>
|
||||
<description>add a `two` function</description>
|
||||
<path>src/lib.rs</path>
|
||||
<operation>insert_after</operation>
|
||||
<old_text>fn zero</old_text>
|
||||
<new_text>
|
||||
fn two() {}
|
||||
</new_text>
|
||||
<search>fn zero</search>
|
||||
<description>add a `two` function</description>
|
||||
</edit>
|
||||
</patch>
|
||||
»
|
||||
</step>»
|
||||
|
||||
also,",
|
||||
&[&[AssistantEdit {
|
||||
&[&[WorkflowStepEdit {
|
||||
path: "src/lib.rs".into(),
|
||||
kind: AssistantEditKind::InsertAfter {
|
||||
old_text: "fn zero".into(),
|
||||
new_text: "fn two() {}".into(),
|
||||
kind: WorkflowStepEditKind::InsertAfter {
|
||||
search: "fn zero".into(),
|
||||
description: "add a `two` function".into(),
|
||||
},
|
||||
}]],
|
||||
@@ -813,58 +834,48 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
|
||||
cx.executor().run_until_parked();
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn expect_patches(
|
||||
fn expect_steps(
|
||||
context: &Model<Context>,
|
||||
expected_marked_text: &str,
|
||||
expected_suggestions: &[&[AssistantEdit]],
|
||||
expected_suggestions: &[&[WorkflowStepEdit]],
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let expected_marked_text = expected_marked_text.unindent();
|
||||
let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
|
||||
|
||||
let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
|
||||
context.update(cx, |context, cx| {
|
||||
let expected_marked_text = expected_marked_text.unindent();
|
||||
let (expected_text, expected_ranges) = marked_text_ranges(&expected_marked_text, false);
|
||||
context.buffer.read_with(cx, |buffer, _| {
|
||||
assert_eq!(buffer.text(), expected_text);
|
||||
let ranges = context
|
||||
.patches
|
||||
.workflow_steps
|
||||
.iter()
|
||||
.map(|entry| entry.range.to_offset(buffer))
|
||||
.collect::<Vec<_>>();
|
||||
(
|
||||
buffer.text(),
|
||||
ranges,
|
||||
context
|
||||
.patches
|
||||
.iter()
|
||||
.map(|step| step.edits.clone())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
})
|
||||
let marked = generate_marked_text(&expected_text, &ranges, false);
|
||||
assert_eq!(
|
||||
marked,
|
||||
expected_marked_text,
|
||||
"unexpected suggestion ranges. actual: {ranges:?}, expected: {expected_ranges:?}"
|
||||
);
|
||||
let suggestions = context
|
||||
.workflow_steps
|
||||
.iter()
|
||||
.map(|step| {
|
||||
step.edits
|
||||
.iter()
|
||||
.map(|edit| {
|
||||
let edit = edit.as_ref().unwrap();
|
||||
WorkflowStepEdit {
|
||||
path: edit.path.clone(),
|
||||
kind: edit.kind.clone(),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(suggestions, expected_suggestions);
|
||||
});
|
||||
});
|
||||
|
||||
assert_eq!(buffer_text, expected_text);
|
||||
|
||||
let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
|
||||
assert_eq!(actual_marked_text, expected_marked_text);
|
||||
|
||||
assert_eq!(
|
||||
patches
|
||||
.iter()
|
||||
.map(|patch| {
|
||||
patch
|
||||
.iter()
|
||||
.map(|edit| {
|
||||
let edit = edit.as_ref().unwrap();
|
||||
AssistantEdit {
|
||||
path: edit.path.clone(),
|
||||
kind: edit.kind.clone(),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
expected_suggestions
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,6 +82,13 @@ pub struct InlineAssistant {
|
||||
assists: HashMap<InlineAssistId, InlineAssist>,
|
||||
assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
|
||||
assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
|
||||
assist_observations: HashMap<
|
||||
InlineAssistId,
|
||||
(
|
||||
async_watch::Sender<AssistStatus>,
|
||||
async_watch::Receiver<AssistStatus>,
|
||||
),
|
||||
>,
|
||||
confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
|
||||
prompt_history: VecDeque<String>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
@@ -89,6 +96,19 @@ pub struct InlineAssistant {
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
pub enum AssistStatus {
|
||||
Idle,
|
||||
Started,
|
||||
Stopped,
|
||||
Finished,
|
||||
}
|
||||
|
||||
impl AssistStatus {
|
||||
pub fn is_done(&self) -> bool {
|
||||
matches!(self, Self::Stopped | Self::Finished)
|
||||
}
|
||||
}
|
||||
|
||||
impl Global for InlineAssistant {}
|
||||
|
||||
impl InlineAssistant {
|
||||
@@ -103,6 +123,7 @@ impl InlineAssistant {
|
||||
assists: HashMap::default(),
|
||||
assists_by_editor: HashMap::default(),
|
||||
assist_groups: HashMap::default(),
|
||||
assist_observations: HashMap::default(),
|
||||
confirmed_assists: HashMap::default(),
|
||||
prompt_history: VecDeque::default(),
|
||||
prompt_builder,
|
||||
@@ -246,7 +267,7 @@ impl InlineAssistant {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: buffer.language().map(|language| language.name().to_proto()),
|
||||
language_name: buffer.language().map(|language| language.name()),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -767,7 +788,7 @@ impl InlineAssistant {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
language_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -814,6 +835,17 @@ impl InlineAssistant {
|
||||
.insert(assist_id, confirmed_alternative);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the assist from the status updates map
|
||||
self.assist_observations.remove(&assist_id);
|
||||
}
|
||||
|
||||
pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
|
||||
let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
|
||||
return false;
|
||||
};
|
||||
codegen.update(cx, |this, cx| this.undo(cx));
|
||||
true
|
||||
}
|
||||
|
||||
fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
|
||||
@@ -1007,6 +1039,10 @@ impl InlineAssistant {
|
||||
codegen.start(user_prompt, assistant_panel_context, cx)
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
|
||||
tx.send(AssistStatus::Started).ok();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
|
||||
@@ -1017,6 +1053,25 @@ impl InlineAssistant {
|
||||
};
|
||||
|
||||
assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
|
||||
|
||||
if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
|
||||
tx.send(AssistStatus::Stopped).ok();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
|
||||
if let Some(assist) = self.assists.get(&assist_id) {
|
||||
match assist.codegen.read(cx).status(cx) {
|
||||
CodegenStatus::Idle => InlineAssistStatus::Idle,
|
||||
CodegenStatus::Pending => InlineAssistStatus::Pending,
|
||||
CodegenStatus::Done => InlineAssistStatus::Done,
|
||||
CodegenStatus::Error(_) => InlineAssistStatus::Error,
|
||||
}
|
||||
} else if self.confirmed_assists.contains_key(&assist_id) {
|
||||
InlineAssistStatus::Confirmed
|
||||
} else {
|
||||
InlineAssistStatus::Canceled
|
||||
}
|
||||
}
|
||||
|
||||
fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
|
||||
@@ -1202,6 +1257,42 @@ impl InlineAssistant {
|
||||
.collect();
|
||||
})
|
||||
}
|
||||
|
||||
pub fn observe_assist(
|
||||
&mut self,
|
||||
assist_id: InlineAssistId,
|
||||
) -> async_watch::Receiver<AssistStatus> {
|
||||
if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
|
||||
rx.clone()
|
||||
} else {
|
||||
let (tx, rx) = async_watch::channel(AssistStatus::Idle);
|
||||
self.assist_observations.insert(assist_id, (tx, rx.clone()));
|
||||
rx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum InlineAssistStatus {
|
||||
Idle,
|
||||
Pending,
|
||||
Done,
|
||||
Error,
|
||||
Confirmed,
|
||||
Canceled,
|
||||
}
|
||||
|
||||
impl InlineAssistStatus {
|
||||
pub(crate) fn is_pending(&self) -> bool {
|
||||
matches!(self, Self::Pending)
|
||||
}
|
||||
|
||||
pub(crate) fn is_confirmed(&self) -> bool {
|
||||
matches!(self, Self::Confirmed)
|
||||
}
|
||||
|
||||
pub(crate) fn is_done(&self) -> bool {
|
||||
matches!(self, Self::Done)
|
||||
}
|
||||
}
|
||||
|
||||
struct EditorInlineAssists {
|
||||
@@ -2187,7 +2278,7 @@ impl InlineAssist {
|
||||
struct InlineAssistantError;
|
||||
|
||||
let id =
|
||||
NotificationId::composite::<InlineAssistantError>(
|
||||
NotificationId::identified::<InlineAssistantError>(
|
||||
assist_id.0,
|
||||
);
|
||||
|
||||
@@ -2199,6 +2290,8 @@ impl InlineAssist {
|
||||
|
||||
if assist.decorations.is_none() {
|
||||
this.finish_assist(assist_id, false, cx);
|
||||
} else if let Some(tx) = this.assist_observations.get(&assist_id) {
|
||||
tx.0.send(AssistStatus::Finished).ok();
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -2861,7 +2954,7 @@ impl CodegenAlternative {
|
||||
model_provider: model_provider_id.to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
language_name,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,746 +0,0 @@
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use editor::ProposedChangesEditor;
|
||||
use futures::{future, TryFutureExt as _};
|
||||
use gpui::{AppContext, AsyncAppContext, Model, SharedString};
|
||||
use language::{AutoindentMode, Buffer, BufferSnapshot};
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{cmp, ops::Range, path::Path, sync::Arc};
|
||||
use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct AssistantPatch {
|
||||
pub range: Range<language::Anchor>,
|
||||
pub title: SharedString,
|
||||
pub edits: Arc<[Result<AssistantEdit>]>,
|
||||
pub status: AssistantPatchStatus,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum AssistantPatchStatus {
|
||||
Pending,
|
||||
Ready,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) struct AssistantEdit {
|
||||
pub path: String,
|
||||
pub kind: AssistantEditKind,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum AssistantEditKind {
|
||||
Update {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
description: String,
|
||||
},
|
||||
Create {
|
||||
new_text: String,
|
||||
description: String,
|
||||
},
|
||||
InsertBefore {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
description: String,
|
||||
},
|
||||
InsertAfter {
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
description: String,
|
||||
},
|
||||
Delete {
|
||||
old_text: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) struct ResolvedPatch {
|
||||
pub edit_groups: HashMap<Model<Buffer>, Vec<ResolvedEditGroup>>,
|
||||
pub errors: Vec<AssistantPatchResolutionError>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ResolvedEditGroup {
|
||||
pub context_range: Range<language::Anchor>,
|
||||
pub edits: Vec<ResolvedEdit>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ResolvedEdit {
|
||||
range: Range<language::Anchor>,
|
||||
new_text: String,
|
||||
description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) struct AssistantPatchResolutionError {
|
||||
pub edit_ix: usize,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum SearchDirection {
|
||||
Up,
|
||||
Left,
|
||||
Diagonal,
|
||||
}
|
||||
|
||||
// A measure of the currently quality of an in-progress fuzzy search.
|
||||
//
|
||||
// Uses 60 bits to store a numeric cost, and 4 bits to store the preceding
|
||||
// operation in the search.
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
score: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(score: u32, direction: SearchDirection) -> Self {
|
||||
Self { score, direction }
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolvedPatch {
|
||||
pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut AppContext) {
|
||||
for (buffer, groups) in &self.edit_groups {
|
||||
let branch = editor.branch_buffer_for_base(buffer).unwrap();
|
||||
Self::apply_edit_groups(groups, &branch, cx);
|
||||
}
|
||||
editor.recalculate_all_buffer_diffs();
|
||||
}
|
||||
|
||||
fn apply_edit_groups(
|
||||
groups: &Vec<ResolvedEditGroup>,
|
||||
buffer: &Model<Buffer>,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let mut edits = Vec::new();
|
||||
for group in groups {
|
||||
for suggestion in &group.edits {
|
||||
edits.push((suggestion.range.clone(), suggestion.new_text.clone()));
|
||||
}
|
||||
}
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(
|
||||
edits,
|
||||
Some(AutoindentMode::Block {
|
||||
original_indent_columns: Vec::new(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolvedEdit {
|
||||
pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool {
|
||||
let range = &self.range;
|
||||
let other_range = &other.range;
|
||||
|
||||
// Don't merge if we don't contain the other suggestion.
|
||||
if range.start.cmp(&other_range.start, buffer).is_gt()
|
||||
|| range.end.cmp(&other_range.end, buffer).is_lt()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(description) = &mut self.description {
|
||||
if let Some(other_description) = &other.description {
|
||||
description.push('\n');
|
||||
description.push_str(other_description);
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantEdit {
|
||||
pub fn new(
|
||||
path: Option<String>,
|
||||
operation: Option<String>,
|
||||
old_text: Option<String>,
|
||||
new_text: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let path = path.ok_or_else(|| anyhow!("missing path"))?;
|
||||
let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
|
||||
|
||||
let kind = match operation.as_str() {
|
||||
"update" => AssistantEditKind::Update {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
"insert_before" => AssistantEditKind::InsertBefore {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
"insert_after" => AssistantEditKind::InsertAfter {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
"delete" => AssistantEditKind::Delete {
|
||||
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
|
||||
},
|
||||
"create" => AssistantEditKind::Create {
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
|
||||
},
|
||||
_ => Err(anyhow!("unknown operation {operation:?}"))?,
|
||||
};
|
||||
|
||||
Ok(Self { path, kind })
|
||||
}
|
||||
|
||||
pub async fn resolve(
|
||||
&self,
|
||||
project: Model<Project>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<(Model<Buffer>, ResolvedEdit)> {
|
||||
let path = self.path.clone();
|
||||
let kind = self.kind.clone();
|
||||
let buffer = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(Path::new(&path), cx)
|
||||
.or_else(|| {
|
||||
// If we couldn't find a project path for it, put it in the active worktree
|
||||
// so that when we create the buffer, it can be saved.
|
||||
let worktree = project
|
||||
.active_entry()
|
||||
.and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
|
||||
.or_else(|| project.worktrees(cx).next())?;
|
||||
let worktree = worktree.read(cx);
|
||||
|
||||
Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: Arc::from(Path::new(&path)),
|
||||
})
|
||||
})
|
||||
.with_context(|| format!("worktree not found for {:?}", path))?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
|
||||
let suggestion = cx
|
||||
.background_executor()
|
||||
.spawn(async move { kind.resolve(&snapshot) })
|
||||
.await;
|
||||
|
||||
Ok((buffer, suggestion))
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantEditKind {
|
||||
fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit {
|
||||
match self {
|
||||
Self::Update {
|
||||
old_text,
|
||||
new_text,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
ResolvedEdit {
|
||||
range,
|
||||
new_text,
|
||||
description: Some(description),
|
||||
}
|
||||
}
|
||||
Self::Create {
|
||||
new_text,
|
||||
description,
|
||||
} => ResolvedEdit {
|
||||
range: text::Anchor::MIN..text::Anchor::MAX,
|
||||
description: Some(description),
|
||||
new_text,
|
||||
},
|
||||
Self::InsertBefore {
|
||||
old_text,
|
||||
mut new_text,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
new_text.push('\n');
|
||||
ResolvedEdit {
|
||||
range: range.start..range.start,
|
||||
new_text,
|
||||
description: Some(description),
|
||||
}
|
||||
}
|
||||
Self::InsertAfter {
|
||||
old_text,
|
||||
mut new_text,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
new_text.insert(0, '\n');
|
||||
ResolvedEdit {
|
||||
range: range.end..range.end,
|
||||
new_text,
|
||||
description: Some(description),
|
||||
}
|
||||
}
|
||||
Self::Delete { old_text } => {
|
||||
let range = Self::resolve_location(&snapshot, &old_text);
|
||||
ResolvedEdit {
|
||||
range,
|
||||
new_text: String::new(),
|
||||
description: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
|
||||
const INSERTION_COST: u32 = 3;
|
||||
const WHITESPACE_INSERTION_COST: u32 = 1;
|
||||
const DELETION_COST: u32 = 3;
|
||||
const WHITESPACE_DELETION_COST: u32 = 1;
|
||||
const EQUALITY_BONUS: u32 = 5;
|
||||
|
||||
struct Matrix {
|
||||
cols: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl Matrix {
|
||||
fn new(rows: usize, cols: usize) -> Self {
|
||||
Matrix {
|
||||
cols,
|
||||
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
|
||||
self.data[row * self.cols + col] = cost;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer_len = buffer.len();
|
||||
let query_len = search_query.len();
|
||||
let mut matrix = Matrix::new(query_len + 1, buffer_len + 1);
|
||||
|
||||
for (row, query_byte) in search_query.bytes().enumerate() {
|
||||
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
|
||||
let deletion_cost = if query_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_DELETION_COST
|
||||
} else {
|
||||
DELETION_COST
|
||||
};
|
||||
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_INSERTION_COST
|
||||
} else {
|
||||
INSERTION_COST
|
||||
};
|
||||
|
||||
let up = SearchState::new(
|
||||
matrix.get(row, col + 1).score.saturating_sub(deletion_cost),
|
||||
SearchDirection::Up,
|
||||
);
|
||||
let left = SearchState::new(
|
||||
matrix
|
||||
.get(row + 1, col)
|
||||
.score
|
||||
.saturating_sub(insertion_cost),
|
||||
SearchDirection::Left,
|
||||
);
|
||||
let diagonal = SearchState::new(
|
||||
if query_byte == *buffer_byte {
|
||||
matrix.get(row, col).score.saturating_add(EQUALITY_BONUS)
|
||||
} else {
|
||||
matrix
|
||||
.get(row, col)
|
||||
.score
|
||||
.saturating_sub(deletion_cost + insertion_cost)
|
||||
},
|
||||
SearchDirection::Diagonal,
|
||||
);
|
||||
matrix.set(row + 1, col + 1, up.max(left).max(diagonal));
|
||||
}
|
||||
}
|
||||
|
||||
// Traceback to find the best match
|
||||
let mut best_buffer_end = buffer_len;
|
||||
let mut best_score = 0;
|
||||
for col in 1..=buffer_len {
|
||||
let score = matrix.get(query_len, col).score;
|
||||
if score > best_score {
|
||||
best_score = score;
|
||||
best_buffer_end = col;
|
||||
}
|
||||
}
|
||||
|
||||
let mut query_ix = query_len;
|
||||
let mut buffer_ix = best_buffer_end;
|
||||
while query_ix > 0 && buffer_ix > 0 {
|
||||
let current = matrix.get(query_ix, buffer_ix);
|
||||
match current.direction {
|
||||
SearchDirection::Diagonal => {
|
||||
query_ix -= 1;
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
SearchDirection::Up => {
|
||||
query_ix -= 1;
|
||||
}
|
||||
SearchDirection::Left => {
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
|
||||
start.column = 0;
|
||||
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
|
||||
if end.column > 0 {
|
||||
end.column = buffer.line_len(end.row);
|
||||
}
|
||||
|
||||
buffer.anchor_after(start)..buffer.anchor_before(end)
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantPatch {
|
||||
pub(crate) async fn resolve(
|
||||
&self,
|
||||
project: Model<Project>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> ResolvedPatch {
|
||||
let mut resolve_tasks = Vec::new();
|
||||
for (ix, edit) in self.edits.iter().enumerate() {
|
||||
if let Ok(edit) = edit.as_ref() {
|
||||
resolve_tasks.push(
|
||||
edit.resolve(project.clone(), cx.clone())
|
||||
.map_err(move |error| (ix, error)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let edits = future::join_all(resolve_tasks).await;
|
||||
let mut errors = Vec::new();
|
||||
let mut edits_by_buffer = HashMap::default();
|
||||
for entry in edits {
|
||||
match entry {
|
||||
Ok((buffer, edit)) => {
|
||||
edits_by_buffer
|
||||
.entry(buffer)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(edit);
|
||||
}
|
||||
Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError {
|
||||
edit_ix,
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Expand the context ranges of each edit and group edits with overlapping context ranges.
|
||||
let mut edit_groups_by_buffer = HashMap::default();
|
||||
for (buffer, edits) in edits_by_buffer {
|
||||
if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) {
|
||||
edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot));
|
||||
}
|
||||
}
|
||||
|
||||
ResolvedPatch {
|
||||
edit_groups: edit_groups_by_buffer,
|
||||
errors,
|
||||
}
|
||||
}
|
||||
|
||||
fn group_edits(
|
||||
mut edits: Vec<ResolvedEdit>,
|
||||
snapshot: &text::BufferSnapshot,
|
||||
) -> Vec<ResolvedEditGroup> {
|
||||
let mut edit_groups = Vec::<ResolvedEditGroup>::new();
|
||||
// Sort edits by their range so that earlier, larger ranges come first
|
||||
edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot));
|
||||
|
||||
// Merge overlapping edits
|
||||
edits.dedup_by(|a, b| b.try_merge(a, &snapshot));
|
||||
|
||||
// Create context ranges for each edit
|
||||
for edit in edits {
|
||||
let context_range = {
|
||||
let edit_point_range = edit.range.to_point(&snapshot);
|
||||
let start_row = edit_point_range.start.row.saturating_sub(5);
|
||||
let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row);
|
||||
let start = snapshot.anchor_before(Point::new(start_row, 0));
|
||||
let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
|
||||
start..end
|
||||
};
|
||||
|
||||
if let Some(last_group) = edit_groups.last_mut() {
|
||||
if last_group
|
||||
.context_range
|
||||
.end
|
||||
.cmp(&context_range.start, &snapshot)
|
||||
.is_ge()
|
||||
{
|
||||
// Merge with the previous group if context ranges overlap
|
||||
last_group.context_range.end = context_range.end;
|
||||
last_group.edits.push(edit);
|
||||
} else {
|
||||
// Create a new group
|
||||
edit_groups.push(ResolvedEditGroup {
|
||||
context_range,
|
||||
edits: vec![edit],
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Create the first group
|
||||
edit_groups.push(ResolvedEditGroup {
|
||||
context_range,
|
||||
edits: vec![edit],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
edit_groups
|
||||
}
|
||||
|
||||
pub fn path_count(&self) -> usize {
|
||||
self.paths().count()
|
||||
}
|
||||
|
||||
pub fn paths(&self) -> impl '_ + Iterator<Item = &str> {
|
||||
let mut prev_path = None;
|
||||
self.edits.iter().filter_map(move |edit| {
|
||||
if let Ok(edit) = edit {
|
||||
let path = Some(edit.path.as_str());
|
||||
if path != prev_path {
|
||||
prev_path = path;
|
||||
return path;
|
||||
}
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for AssistantPatch {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.range == other.range
|
||||
&& self.title == other.title
|
||||
&& Arc::ptr_eq(&self.edits, &other.edits)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for AssistantPatch {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{AppContext, Context};
|
||||
use language::{
|
||||
language_settings::AllLanguageSettings, Language, LanguageConfig, LanguageMatcher,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use text::{OffsetRangeExt, Point};
|
||||
use ui::BorrowAppContext;
|
||||
use unindent::Unindent as _;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_location(cx: &mut AppContext) {
|
||||
{
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::local(
|
||||
concat!(
|
||||
" Lorem\n",
|
||||
" ipsum\n",
|
||||
" dolor sit amet\n",
|
||||
" consecteur",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
AssistantEditKind::resolve_location(&snapshot, "ipsum\ndolor").to_point(&snapshot),
|
||||
Point::new(1, 0)..Point::new(2, 18)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::local(
|
||||
concat!(
|
||||
"fn foo1(a: usize) -> usize {\n",
|
||||
" 40\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"fn foo2(b: usize) -> usize {\n",
|
||||
" 42\n",
|
||||
"}\n",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
AssistantEditKind::resolve_location(&snapshot, "fn foo1(b: usize) {\n40\n}")
|
||||
.to_point(&snapshot),
|
||||
Point::new(0, 0)..Point::new(2, 1)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::local(
|
||||
concat!(
|
||||
"fn main() {\n",
|
||||
" Foo\n",
|
||||
" .bar()\n",
|
||||
" .baz()\n",
|
||||
" .qux()\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"fn foo2(b: usize) -> usize {\n",
|
||||
" 42\n",
|
||||
"}\n",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
AssistantEditKind::resolve_location(&snapshot, "Foo.bar.baz.qux()")
|
||||
.to_point(&snapshot),
|
||||
Point::new(1, 0)..Point::new(4, 14)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_edits(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
cx.update_global::<SettingsStore, _>(|settings, cx| {
|
||||
settings.update_user_settings::<AllLanguageSettings>(cx, |_| {});
|
||||
});
|
||||
|
||||
assert_edits(
|
||||
"
|
||||
/// A person
|
||||
struct Person {
|
||||
name: String,
|
||||
age: usize,
|
||||
}
|
||||
|
||||
/// A dog
|
||||
struct Dog {
|
||||
weight: f32,
|
||||
}
|
||||
|
||||
impl Person {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
vec![
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
name: String,
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
"
|
||||
.unindent(),
|
||||
description: "".into(),
|
||||
},
|
||||
AssistantEditKind::Update {
|
||||
old_text: "
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
new_text: "
|
||||
fn name(&self) -> String {
|
||||
format!(\"{} {}\", self.first_name, self.last_name)
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
description: "".into(),
|
||||
},
|
||||
],
|
||||
"
|
||||
/// A person
|
||||
struct Person {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
age: usize,
|
||||
}
|
||||
|
||||
/// A dog
|
||||
struct Dog {
|
||||
weight: f32,
|
||||
}
|
||||
|
||||
impl Person {
|
||||
fn name(&self) -> String {
|
||||
format!(\"{} {}\", self.first_name, self.last_name)
|
||||
}
|
||||
}
|
||||
"
|
||||
.unindent(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_edits(
|
||||
old_text: String,
|
||||
edits: Vec<AssistantEditKind>,
|
||||
new_text: String,
|
||||
cx: &mut AppContext,
|
||||
) {
|
||||
let buffer =
|
||||
cx.new_model(|cx| Buffer::local(old_text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let resolved_edits = edits
|
||||
.into_iter()
|
||||
.map(|kind| kind.resolve(&snapshot))
|
||||
.collect();
|
||||
let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot);
|
||||
ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx);
|
||||
let actual_new_text = buffer.read(cx).text();
|
||||
pretty_assertions::assert_eq!(actual_new_text, new_text);
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(language::tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_indents_query(
|
||||
r#"
|
||||
(call_expression) @indent
|
||||
(field_expression) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
(_ "{" "}" @end) @indent
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
@@ -521,9 +521,9 @@ impl PromptLibrary {
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_use_modal_editing(false);
|
||||
editor.set_current_line_highlight(Some(CurrentLineHighlight::None));
|
||||
editor.set_completion_provider(Some(Box::new(
|
||||
editor.set_completion_provider(Box::new(
|
||||
SlashCommandCompletionProvider::new(None, None),
|
||||
)));
|
||||
));
|
||||
if focus {
|
||||
editor.focus(cx);
|
||||
}
|
||||
|
||||
@@ -45,6 +45,15 @@ pub struct ProjectSlashCommandPromptContext {
|
||||
pub context_buffer: String,
|
||||
}
|
||||
|
||||
/// Context required to generate a workflow step resolution prompt.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct StepResolutionContext {
|
||||
/// The full context, including <step>...</step> tags
|
||||
pub workflow_context: String,
|
||||
/// The text of the specific step from the context to resolve
|
||||
pub step_to_resolve: String,
|
||||
}
|
||||
|
||||
pub struct PromptLoadingParams<'a> {
|
||||
pub fs: Arc<dyn Fs>,
|
||||
pub repo_path: Option<PathBuf>,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use super::create_label_for_command;
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{
|
||||
AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
|
||||
@@ -7,9 +6,9 @@ use assistant_slash_command::{
|
||||
use collections::HashMap;
|
||||
use context_servers::{
|
||||
manager::{ContextServer, ContextServerManager},
|
||||
types::Prompt,
|
||||
protocol::PromptInfo,
|
||||
};
|
||||
use gpui::{AppContext, Task, WeakView, WindowContext};
|
||||
use gpui::{Task, WeakView, WindowContext};
|
||||
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
@@ -19,11 +18,11 @@ use workspace::Workspace;
|
||||
|
||||
pub struct ContextServerSlashCommand {
|
||||
server_id: String,
|
||||
prompt: Prompt,
|
||||
prompt: PromptInfo,
|
||||
}
|
||||
|
||||
impl ContextServerSlashCommand {
|
||||
pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
|
||||
pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
|
||||
Self {
|
||||
server_id: server.id.clone(),
|
||||
prompt,
|
||||
@@ -36,28 +35,12 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||
self.prompt.name.clone()
|
||||
}
|
||||
|
||||
fn label(&self, cx: &AppContext) -> language::CodeLabel {
|
||||
let mut parts = vec![self.prompt.name.as_str()];
|
||||
if let Some(args) = &self.prompt.arguments {
|
||||
if let Some(arg) = args.first() {
|
||||
parts.push(arg.name.as_str());
|
||||
}
|
||||
}
|
||||
create_label_for_command(&parts[0], &parts[1..], cx)
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
match &self.prompt.description {
|
||||
Some(desc) => desc.clone(),
|
||||
None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
|
||||
}
|
||||
format!("Run context server command: {}", self.prompt.name)
|
||||
}
|
||||
|
||||
fn menu_text(&self) -> String {
|
||||
match &self.prompt.description {
|
||||
Some(desc) => desc.clone(),
|
||||
None => format!("Run '{}' from {}", self.prompt.name, self.server_id),
|
||||
}
|
||||
format!("Run '{}' from {}", self.prompt.name, self.server_id)
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
@@ -171,7 +154,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||
}
|
||||
}
|
||||
|
||||
fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String, String)> {
|
||||
fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
|
||||
if arguments.is_empty() {
|
||||
return Err(anyhow!("No arguments given"));
|
||||
}
|
||||
@@ -187,7 +170,7 @@ fn completion_argument(prompt: &Prompt, arguments: &[String]) -> Result<(String,
|
||||
}
|
||||
}
|
||||
|
||||
fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<String, String>> {
|
||||
fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
|
||||
match &prompt.arguments {
|
||||
Some(args) if args.len() > 1 => Err(anyhow!(
|
||||
"Prompt has more than one argument, which is not supported"
|
||||
@@ -216,7 +199,7 @@ fn prompt_arguments(prompt: &Prompt, arguments: &[String]) -> Result<HashMap<Str
|
||||
/// MCP servers can return prompts with multiple arguments. Since we only
|
||||
/// support one argument, we ignore all others. This is the necessary predicate
|
||||
/// for this.
|
||||
pub fn acceptable_prompt(prompt: &Prompt) -> bool {
|
||||
pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
|
||||
match &prompt.arguments {
|
||||
None => true,
|
||||
Some(args) if args.len() <= 1 => true,
|
||||
|
||||
@@ -18,8 +18,6 @@ pub(crate) struct WorkflowSlashCommand {
|
||||
}
|
||||
|
||||
impl WorkflowSlashCommand {
|
||||
pub const NAME: &'static str = "workflow";
|
||||
|
||||
pub fn new(prompt_builder: Arc<PromptBuilder>) -> Self {
|
||||
Self { prompt_builder }
|
||||
}
|
||||
@@ -27,7 +25,7 @@ impl WorkflowSlashCommand {
|
||||
|
||||
impl SlashCommand for WorkflowSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
Self::NAME.into()
|
||||
"workflow".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
|
||||
@@ -38,10 +38,7 @@ impl Settings for SlashCommandSettings {
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _cx: &mut AppContext) -> Result<Self> {
|
||||
SettingsSources::<Self::FileContent>::json_merge_with(
|
||||
[sources.default]
|
||||
.into_iter()
|
||||
.chain(sources.user)
|
||||
.chain(sources.server),
|
||||
[sources.default].into_iter().chain(sources.user),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ impl TerminalInlineAssist {
|
||||
struct InlineAssistantError;
|
||||
|
||||
let id =
|
||||
NotificationId::composite::<InlineAssistantError>(
|
||||
NotificationId::identified::<InlineAssistantError>(
|
||||
assist_id.0,
|
||||
);
|
||||
|
||||
|
||||
507
crates/assistant/src/workflow.rs
Normal file
507
crates/assistant/src/workflow.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
use crate::{AssistantPanel, InlineAssistId, InlineAssistant};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use editor::Editor;
|
||||
use gpui::AsyncAppContext;
|
||||
use gpui::{Model, Task, UpdateGlobal as _, View, WeakView, WindowContext};
|
||||
use language::{Buffer, BufferSnapshot};
|
||||
use project::{Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{ops::Range, path::Path, sync::Arc};
|
||||
use text::Bias;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct WorkflowStep {
|
||||
pub range: Range<language::Anchor>,
|
||||
pub leading_tags_end: text::Anchor,
|
||||
pub trailing_tag_start: Option<text::Anchor>,
|
||||
pub edits: Arc<[Result<WorkflowStepEdit>]>,
|
||||
pub resolution_task: Option<Task<()>>,
|
||||
pub resolution: Option<Arc<Result<WorkflowStepResolution>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) struct WorkflowStepEdit {
|
||||
pub path: String,
|
||||
pub kind: WorkflowStepEditKind,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) struct WorkflowStepResolution {
|
||||
pub title: String,
|
||||
pub suggestion_groups: HashMap<Model<Buffer>, Vec<WorkflowSuggestionGroup>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct WorkflowSuggestionGroup {
|
||||
pub context_range: Range<language::Anchor>,
|
||||
pub suggestions: Vec<WorkflowSuggestion>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum WorkflowSuggestion {
|
||||
Update {
|
||||
range: Range<language::Anchor>,
|
||||
description: String,
|
||||
},
|
||||
CreateFile {
|
||||
description: String,
|
||||
},
|
||||
InsertBefore {
|
||||
position: language::Anchor,
|
||||
description: String,
|
||||
},
|
||||
InsertAfter {
|
||||
position: language::Anchor,
|
||||
description: String,
|
||||
},
|
||||
Delete {
|
||||
range: Range<language::Anchor>,
|
||||
},
|
||||
}
|
||||
|
||||
impl WorkflowSuggestion {
|
||||
pub fn range(&self) -> Range<language::Anchor> {
|
||||
match self {
|
||||
Self::Update { range, .. } => range.clone(),
|
||||
Self::CreateFile { .. } => language::Anchor::MIN..language::Anchor::MAX,
|
||||
Self::InsertBefore { position, .. } | Self::InsertAfter { position, .. } => {
|
||||
*position..*position
|
||||
}
|
||||
Self::Delete { range, .. } => range.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn description(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Update { description, .. }
|
||||
| Self::CreateFile { description }
|
||||
| Self::InsertBefore { description, .. }
|
||||
| Self::InsertAfter { description, .. } => Some(description),
|
||||
Self::Delete { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn description_mut(&mut self) -> Option<&mut String> {
|
||||
match self {
|
||||
Self::Update { description, .. }
|
||||
| Self::CreateFile { description }
|
||||
| Self::InsertBefore { description, .. }
|
||||
| Self::InsertAfter { description, .. } => Some(description),
|
||||
Self::Delete { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_merge(&mut self, other: &Self, buffer: &BufferSnapshot) -> bool {
|
||||
let range = self.range();
|
||||
let other_range = other.range();
|
||||
|
||||
// Don't merge if we don't contain the other suggestion.
|
||||
if range.start.cmp(&other_range.start, buffer).is_gt()
|
||||
|| range.end.cmp(&other_range.end, buffer).is_lt()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(description) = self.description_mut() {
|
||||
if let Some(other_description) = other.description() {
|
||||
description.push('\n');
|
||||
description.push_str(other_description);
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn show(
|
||||
&self,
|
||||
editor: &View<Editor>,
|
||||
excerpt_id: editor::ExcerptId,
|
||||
workspace: &WeakView<Workspace>,
|
||||
assistant_panel: &View<AssistantPanel>,
|
||||
cx: &mut WindowContext,
|
||||
) -> Option<InlineAssistId> {
|
||||
let mut initial_transaction_id = None;
|
||||
let initial_prompt;
|
||||
let suggestion_range;
|
||||
let buffer = editor.read(cx).buffer().clone();
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
|
||||
match self {
|
||||
Self::Update {
|
||||
range, description, ..
|
||||
} => {
|
||||
initial_prompt = description.clone();
|
||||
suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
|
||||
..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
|
||||
}
|
||||
Self::CreateFile { description } => {
|
||||
initial_prompt = description.clone();
|
||||
suggestion_range = editor::Anchor::min()..editor::Anchor::min();
|
||||
}
|
||||
Self::InsertBefore {
|
||||
position,
|
||||
description,
|
||||
..
|
||||
} => {
|
||||
let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
|
||||
initial_prompt = description.clone();
|
||||
suggestion_range = buffer.update(cx, |buffer, cx| {
|
||||
buffer.start_transaction(cx);
|
||||
let line_start = buffer.insert_empty_line(position, true, true, cx);
|
||||
initial_transaction_id = buffer.end_transaction(cx);
|
||||
buffer.refresh_preview(cx);
|
||||
|
||||
let line_start = buffer.read(cx).anchor_before(line_start);
|
||||
line_start..line_start
|
||||
});
|
||||
}
|
||||
Self::InsertAfter {
|
||||
position,
|
||||
description,
|
||||
..
|
||||
} => {
|
||||
let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
|
||||
initial_prompt = description.clone();
|
||||
suggestion_range = buffer.update(cx, |buffer, cx| {
|
||||
buffer.start_transaction(cx);
|
||||
let line_start = buffer.insert_empty_line(position, true, true, cx);
|
||||
initial_transaction_id = buffer.end_transaction(cx);
|
||||
buffer.refresh_preview(cx);
|
||||
|
||||
let line_start = buffer.read(cx).anchor_before(line_start);
|
||||
line_start..line_start
|
||||
});
|
||||
}
|
||||
Self::Delete { range, .. } => {
|
||||
initial_prompt = "Delete".to_string();
|
||||
suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
|
||||
..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
|
||||
}
|
||||
}
|
||||
|
||||
InlineAssistant::update_global(cx, |inline_assistant, cx| {
|
||||
Some(inline_assistant.suggest_assist(
|
||||
editor,
|
||||
suggestion_range,
|
||||
initial_prompt,
|
||||
initial_transaction_id,
|
||||
false,
|
||||
Some(workspace.clone()),
|
||||
Some(assistant_panel),
|
||||
cx,
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl WorkflowStepEdit {
|
||||
pub fn new(
|
||||
path: Option<String>,
|
||||
operation: Option<String>,
|
||||
search: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let path = path.ok_or_else(|| anyhow!("missing path"))?;
|
||||
let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
|
||||
|
||||
let kind = match operation.as_str() {
|
||||
"update" => WorkflowStepEditKind::Update {
|
||||
search: search.ok_or_else(|| anyhow!("missing search"))?,
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
"insert_before" => WorkflowStepEditKind::InsertBefore {
|
||||
search: search.ok_or_else(|| anyhow!("missing search"))?,
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
"insert_after" => WorkflowStepEditKind::InsertAfter {
|
||||
search: search.ok_or_else(|| anyhow!("missing search"))?,
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
"delete" => WorkflowStepEditKind::Delete {
|
||||
search: search.ok_or_else(|| anyhow!("missing search"))?,
|
||||
},
|
||||
"create" => WorkflowStepEditKind::Create {
|
||||
description: description.ok_or_else(|| anyhow!("missing description"))?,
|
||||
},
|
||||
_ => Err(anyhow!("unknown operation {operation:?}"))?,
|
||||
};
|
||||
|
||||
Ok(Self { path, kind })
|
||||
}
|
||||
|
||||
pub async fn resolve(
|
||||
&self,
|
||||
project: Model<Project>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<(Model<Buffer>, super::WorkflowSuggestion)> {
|
||||
let path = self.path.clone();
|
||||
let kind = self.kind.clone();
|
||||
let buffer = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(Path::new(&path), cx)
|
||||
.or_else(|| {
|
||||
// If we couldn't find a project path for it, put it in the active worktree
|
||||
// so that when we create the buffer, it can be saved.
|
||||
let worktree = project
|
||||
.active_entry()
|
||||
.and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
|
||||
.or_else(|| project.worktrees(cx).next())?;
|
||||
let worktree = worktree.read(cx);
|
||||
|
||||
Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: Arc::from(Path::new(&path)),
|
||||
})
|
||||
})
|
||||
.with_context(|| format!("worktree not found for {:?}", path))?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
|
||||
let suggestion = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
match kind {
|
||||
WorkflowStepEditKind::Update {
|
||||
search,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &search);
|
||||
WorkflowSuggestion::Update { range, description }
|
||||
}
|
||||
WorkflowStepEditKind::Create { description } => {
|
||||
WorkflowSuggestion::CreateFile { description }
|
||||
}
|
||||
WorkflowStepEditKind::InsertBefore {
|
||||
search,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &search);
|
||||
WorkflowSuggestion::InsertBefore {
|
||||
position: range.start,
|
||||
description,
|
||||
}
|
||||
}
|
||||
WorkflowStepEditKind::InsertAfter {
|
||||
search,
|
||||
description,
|
||||
} => {
|
||||
let range = Self::resolve_location(&snapshot, &search);
|
||||
WorkflowSuggestion::InsertAfter {
|
||||
position: range.end,
|
||||
description,
|
||||
}
|
||||
}
|
||||
WorkflowStepEditKind::Delete { search } => {
|
||||
let range = Self::resolve_location(&snapshot, &search);
|
||||
WorkflowSuggestion::Delete { range }
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok((buffer, suggestion))
|
||||
}
|
||||
|
||||
fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
|
||||
const INSERTION_SCORE: f64 = -1.0;
|
||||
const DELETION_SCORE: f64 = -1.0;
|
||||
const REPLACEMENT_SCORE: f64 = -1.0;
|
||||
const EQUALITY_SCORE: f64 = 5.0;
|
||||
|
||||
struct Matrix {
|
||||
cols: usize,
|
||||
data: Vec<f64>,
|
||||
}
|
||||
|
||||
impl Matrix {
|
||||
fn new(rows: usize, cols: usize) -> Self {
|
||||
Matrix {
|
||||
cols,
|
||||
data: vec![0.0; rows * cols],
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> f64 {
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, value: f64) {
|
||||
self.data[row * self.cols + col] = value;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer_len = buffer.len();
|
||||
let query_len = search_query.len();
|
||||
let mut matrix = Matrix::new(query_len + 1, buffer_len + 1);
|
||||
|
||||
for (i, query_byte) in search_query.bytes().enumerate() {
|
||||
for (j, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
|
||||
let match_score = if query_byte == *buffer_byte {
|
||||
EQUALITY_SCORE
|
||||
} else {
|
||||
REPLACEMENT_SCORE
|
||||
};
|
||||
let up = matrix.get(i + 1, j) + DELETION_SCORE;
|
||||
let left = matrix.get(i, j + 1) + INSERTION_SCORE;
|
||||
let diagonal = matrix.get(i, j) + match_score;
|
||||
let score = up.max(left.max(diagonal)).max(0.);
|
||||
matrix.set(i + 1, j + 1, score);
|
||||
}
|
||||
}
|
||||
|
||||
// Traceback to find the best match
|
||||
let mut best_buffer_end = buffer_len;
|
||||
let mut best_score = 0.0;
|
||||
for col in 1..=buffer_len {
|
||||
let score = matrix.get(query_len, col);
|
||||
if score > best_score {
|
||||
best_score = score;
|
||||
best_buffer_end = col;
|
||||
}
|
||||
}
|
||||
|
||||
let mut query_ix = query_len;
|
||||
let mut buffer_ix = best_buffer_end;
|
||||
while query_ix > 0 && buffer_ix > 0 {
|
||||
let current = matrix.get(query_ix, buffer_ix);
|
||||
let up = matrix.get(query_ix - 1, buffer_ix);
|
||||
let left = matrix.get(query_ix, buffer_ix - 1);
|
||||
if current == left + INSERTION_SCORE {
|
||||
buffer_ix -= 1;
|
||||
} else if current == up + DELETION_SCORE {
|
||||
query_ix -= 1;
|
||||
} else {
|
||||
query_ix -= 1;
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
|
||||
start.column = 0;
|
||||
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
|
||||
end.column = buffer.line_len(end.row);
|
||||
|
||||
buffer.anchor_after(start)..buffer.anchor_before(end)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "operation")]
|
||||
pub enum WorkflowStepEditKind {
|
||||
/// Rewrites the specified text entirely based on the given description.
|
||||
/// This operation completely replaces the given text.
|
||||
Update {
|
||||
/// A string in the source text to apply the update to.
|
||||
search: String,
|
||||
/// A brief description of the transformation to apply to the symbol.
|
||||
description: String,
|
||||
},
|
||||
/// Creates a new file with the given path based on the provided description.
|
||||
/// This operation adds a new file to the codebase.
|
||||
Create {
|
||||
/// A brief description of the file to be created.
|
||||
description: String,
|
||||
},
|
||||
/// Inserts text before the specified text in the source file.
|
||||
InsertBefore {
|
||||
/// A string in the source text to insert text before.
|
||||
search: String,
|
||||
/// A brief description of how the new text should be generated.
|
||||
description: String,
|
||||
},
|
||||
/// Inserts text after the specified text in the source file.
|
||||
InsertAfter {
|
||||
/// A string in the source text to insert text after.
|
||||
search: String,
|
||||
/// A brief description of how the new text should be generated.
|
||||
description: String,
|
||||
},
|
||||
/// Deletes the specified symbol from the containing file.
|
||||
Delete {
|
||||
/// A string in the source text to delete.
|
||||
search: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{AppContext, Context};
|
||||
use text::{OffsetRangeExt, Point};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_resolve_location(cx: &mut AppContext) {
|
||||
{
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::local(
|
||||
concat!(
|
||||
" Lorem\n",
|
||||
" ipsum\n",
|
||||
" dolor sit amet\n",
|
||||
" consecteur",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
WorkflowStepEdit::resolve_location(&snapshot, "ipsum\ndolor").to_point(&snapshot),
|
||||
Point::new(1, 0)..Point::new(2, 18)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::local(
|
||||
concat!(
|
||||
"fn foo1(a: usize) -> usize {\n",
|
||||
" 42\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"fn foo2(b: usize) -> usize {\n",
|
||||
" 42\n",
|
||||
"}\n",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
WorkflowStepEdit::resolve_location(&snapshot, "fn foo1(b: usize) {\n42\n}")
|
||||
.to_point(&snapshot),
|
||||
Point::new(0, 0)..Point::new(2, 1)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::local(
|
||||
concat!(
|
||||
"fn main() {\n",
|
||||
" Foo\n",
|
||||
" .bar()\n",
|
||||
" .baz()\n",
|
||||
" .qux()\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"fn foo2(b: usize) -> usize {\n",
|
||||
" 42\n",
|
||||
"}\n",
|
||||
),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
WorkflowStepEdit::resolve_location(&snapshot, "Foo.bar.baz.qux()")
|
||||
.to_point(&snapshot),
|
||||
Point::new(1, 0)..Point::new(4, 14)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -130,7 +130,7 @@ impl Settings for AutoUpdateSetting {
|
||||
type FileContent = Option<AutoUpdateSettingContent>;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
let auto_update = [sources.server, sources.release_channel, sources.user]
|
||||
let auto_update = [sources.release_channel, sources.user]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
@@ -464,7 +464,6 @@ impl AutoUpdater {
|
||||
smol::fs::create_dir_all(&platform_dir).await.ok();
|
||||
|
||||
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
|
||||
|
||||
if smol::fs::metadata(&version_path).await.is_err() {
|
||||
log::info!("downloading zed-remote-server {os} {arch}");
|
||||
download_remote_server_binary(&version_path, release, client, cx).await?;
|
||||
|
||||
@@ -1178,7 +1178,7 @@ impl Room {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.joined_projects.retain(|project| {
|
||||
if let Some(project) = project.upgrade() {
|
||||
!project.read(cx).is_disconnected(cx)
|
||||
!project.read(cx).is_disconnected()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -34,8 +34,8 @@ postage.workspace = true
|
||||
rand.workspace = true
|
||||
release_channel.workspace = true
|
||||
rpc = { workspace = true, features = ["gpui"] }
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -4,7 +4,6 @@ pub mod test;
|
||||
mod socks;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
|
||||
use anyhow::{anyhow, bail, Context as _, Result};
|
||||
use async_recursion::async_recursion;
|
||||
@@ -142,7 +141,6 @@ impl Settings for ProxySettings {
|
||||
Ok(Self {
|
||||
proxy: sources
|
||||
.user
|
||||
.or(sources.server)
|
||||
.and_then(|value| value.proxy.clone())
|
||||
.or(sources.default.proxy.clone()),
|
||||
})
|
||||
@@ -474,21 +472,15 @@ impl settings::Settings for TelemetrySettings {
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
Ok(Self {
|
||||
diagnostics: sources
|
||||
.user
|
||||
.as_ref()
|
||||
.or(sources.server.as_ref())
|
||||
.and_then(|v| v.diagnostics)
|
||||
.unwrap_or(
|
||||
sources
|
||||
.default
|
||||
.diagnostics
|
||||
.ok_or_else(Self::missing_default)?,
|
||||
),
|
||||
diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
|
||||
sources
|
||||
.default
|
||||
.diagnostics
|
||||
.ok_or_else(Self::missing_default)?,
|
||||
),
|
||||
metrics: sources
|
||||
.user
|
||||
.as_ref()
|
||||
.or(sources.server.as_ref())
|
||||
.and_then(|v| v.metrics)
|
||||
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
|
||||
})
|
||||
@@ -1031,7 +1023,7 @@ impl Client {
|
||||
&self,
|
||||
http: Arc<HttpClientWithUrl>,
|
||||
release_channel: Option<ReleaseChannel>,
|
||||
) -> impl Future<Output = Result<url::Url>> {
|
||||
) -> impl Future<Output = Result<Url>> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
let url_override = self.rpc_url.read().clone();
|
||||
|
||||
@@ -1125,7 +1117,7 @@ impl Client {
|
||||
// for us from the RPC URL.
|
||||
//
|
||||
// Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
|
||||
let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
|
||||
let mut request = rpc_url.into_client_request()?;
|
||||
|
||||
// We then modify the request to add our desired headers.
|
||||
let request_headers = request.headers_mut();
|
||||
@@ -1164,7 +1156,6 @@ impl Client {
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let (stream, _) =
|
||||
async_tungstenite::async_tls::client_async_tls_with_connector(
|
||||
request,
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
//! Contains helper functions for constructing URLs to various Zed-related pages.
|
||||
//!
|
||||
//! These URLs will adapt to the configured server URL in order to construct
|
||||
//! links appropriate for the environment (e.g., by linking to a local copy of
|
||||
//! zed.dev in development).
|
||||
|
||||
use gpui::AppContext;
|
||||
use settings::Settings;
|
||||
|
||||
use crate::ClientSettings;
|
||||
|
||||
fn server_url(cx: &AppContext) -> &str {
|
||||
&ClientSettings::get_global(cx).server_url
|
||||
}
|
||||
|
||||
/// Returns the URL to the account page on zed.dev.
|
||||
pub fn account_url(cx: &AppContext) -> String {
|
||||
format!("{server_url}/account", server_url = server_url(cx))
|
||||
}
|
||||
@@ -32,12 +32,12 @@ clickhouse.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
isahc_http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
live_kit_server.workspace = true
|
||||
log.workspace = true
|
||||
@@ -48,7 +48,6 @@ prometheus = "0.13"
|
||||
prost.workspace = true
|
||||
rand.workspace = true
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
reqwest_client.workspace = true
|
||||
rpc.workspace = true
|
||||
rustc-demangle.workspace = true
|
||||
scrypt = "0.11"
|
||||
@@ -67,7 +66,7 @@ telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tokio.workspace = true
|
||||
toml.workspace = true
|
||||
tower = "0.4"
|
||||
tower-http = { workspace = true, features = ["trace"] }
|
||||
|
||||
@@ -199,12 +199,6 @@ spec:
|
||||
secretKeyRef:
|
||||
name: slack
|
||||
key: panics_webhook
|
||||
- name: STRIPE_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: stripe
|
||||
key: api_key
|
||||
optional: true
|
||||
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
|
||||
value: "1000"
|
||||
- name: SUPERMAVEN_ADMIN_API_KEY
|
||||
|
||||
@@ -422,15 +422,6 @@ CREATE TABLE dev_server_projects (
|
||||
paths TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_customers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
create table if not exists billing_preferences (
|
||||
id serial primary key,
|
||||
created_at timestamp without time zone not null default now(),
|
||||
user_id integer not null references users(id) on delete cascade,
|
||||
max_monthly_llm_usage_spending_in_cents integer not null
|
||||
);
|
||||
|
||||
create unique index "uix_billing_preferences_on_user_id" on billing_preferences (user_id);
|
||||
@@ -1,12 +0,0 @@
|
||||
create table billing_events (
|
||||
id serial primary key,
|
||||
idempotency_key uuid not null default gen_random_uuid(),
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
input_tokens bigint not null default 0,
|
||||
input_cache_creation_tokens bigint not null default 0,
|
||||
input_cache_read_tokens bigint not null default 0,
|
||||
output_tokens bigint not null default 0
|
||||
);
|
||||
|
||||
create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);
|
||||
@@ -1,3 +1,7 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use axum::{
|
||||
extract::{self, Query},
|
||||
@@ -5,43 +9,32 @@ use axum::{
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use collections::HashSet;
|
||||
use reqwest::StatusCode;
|
||||
use sea_orm::ActiveValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use stripe::{
|
||||
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
|
||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
|
||||
Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::{
|
||||
db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
},
|
||||
stripe_billing::StripeBilling,
|
||||
};
|
||||
use crate::{
|
||||
db::{billing_subscription::StripeSubscriptionStatus, UserId},
|
||||
llm::db::LlmDatabase,
|
||||
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
|
||||
use crate::db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
};
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS;
|
||||
use crate::rpc::ResultExt as _;
|
||||
use crate::{AppState, Error, Result};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route(
|
||||
"/billing/preferences",
|
||||
get(get_billing_preferences).put(update_billing_preferences),
|
||||
)
|
||||
.route(
|
||||
"/billing/subscriptions",
|
||||
get(list_billing_subscriptions).post(create_billing_subscription),
|
||||
@@ -50,86 +43,6 @@ pub fn router() -> Router {
|
||||
"/billing/subscriptions/manage",
|
||||
post(manage_billing_subscription),
|
||||
)
|
||||
.route("/billing/monthly_spend", get(get_monthly_spend))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetBillingPreferencesParams {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn get_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetBillingPreferencesParams>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(params.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let preferences = app.db.get_billing_preferences(user.id).await?;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents
|
||||
}),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpdateBillingPreferencesBody {
|
||||
github_user_id: i32,
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn update_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
|
||||
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let billing_preferences =
|
||||
if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
|
||||
app.db
|
||||
.update_billing_preferences(
|
||||
user.id,
|
||||
&UpdateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
body.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
},
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
app.db
|
||||
.create_billing_preferences(
|
||||
user.id,
|
||||
&crate::db::CreateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents: body
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
rpc_server.refresh_llm_tokens_for_user(user.id).await;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -204,22 +117,12 @@ async fn create_billing_subscription(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::error!("failed to retrieve LLM database");
|
||||
let Some((stripe_client, stripe_price_id)) = app
|
||||
.stripe_client
|
||||
.clone()
|
||||
.zip(app.config.stripe_llm_usage_price_id.clone())
|
||||
else {
|
||||
log::error!("failed to retrieve Stripe client or price ID");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
@@ -243,14 +146,26 @@ async fn create_billing_subscription(
|
||||
customer.id
|
||||
};
|
||||
|
||||
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
let checkout_session_url = stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?;
|
||||
let checkout_session = {
|
||||
let mut params = CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(user.github_login.as_str());
|
||||
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
|
||||
price: Some(stripe_price_id.to_string()),
|
||||
quantity: Some(0),
|
||||
..Default::default()
|
||||
}]);
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
params.success_url = Some(&success_url);
|
||||
|
||||
CheckoutSession::create(&stripe_client, params).await?
|
||||
};
|
||||
|
||||
Ok(Json(CreateBillingSubscriptionResponse {
|
||||
checkout_session_url,
|
||||
checkout_session_url: checkout_session
|
||||
.url
|
||||
.ok_or_else(|| anyhow!("no checkout session URL"))?,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -405,7 +320,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
|
||||
|
||||
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||
/// database with the data in Stripe.
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
@@ -416,9 +331,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
poll_stripe_events(&app, &rpc_server, &stripe_client)
|
||||
.await
|
||||
.log_err();
|
||||
poll_stripe_events(&app, &stripe_client).await.log_err();
|
||||
|
||||
executor.sleep(POLL_EVENTS_INTERVAL).await;
|
||||
}
|
||||
@@ -428,7 +341,6 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
|
||||
|
||||
async fn poll_stripe_events(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
@@ -453,28 +365,29 @@ async fn poll_stripe_events(
|
||||
let mut pages_of_already_processed_events = 0;
|
||||
let mut unprocessed_events = Vec::new();
|
||||
|
||||
log::info!(
|
||||
"Stripe events: starting retrieval for {}",
|
||||
event_types.join(", ")
|
||||
);
|
||||
let mut params = ListEvents::new();
|
||||
params.types = Some(event_types.clone());
|
||||
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
|
||||
|
||||
let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
|
||||
.await?
|
||||
.paginate(params);
|
||||
|
||||
loop {
|
||||
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP {
|
||||
log::info!("saw {pages_of_already_processed_events} pages of already-processed events: stopping event retrieval");
|
||||
break;
|
||||
}
|
||||
|
||||
log::info!("retrieving events from Stripe: {}", event_types.join(", "));
|
||||
|
||||
let mut params = ListEvents::new();
|
||||
params.types = Some(event_types.clone());
|
||||
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
|
||||
|
||||
let events = stripe::Event::list(stripe_client, ¶ms).await?;
|
||||
|
||||
let processed_event_ids = {
|
||||
let event_ids = event_pages
|
||||
.page
|
||||
let event_ids = &events
|
||||
.data
|
||||
.iter()
|
||||
.map(|event| event.id.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
app.db
|
||||
.get_processed_stripe_events_by_event_ids(&event_ids)
|
||||
.get_processed_stripe_events_by_event_ids(event_ids)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|event| event.stripe_event_id)
|
||||
@@ -482,13 +395,13 @@ async fn poll_stripe_events(
|
||||
};
|
||||
|
||||
let mut processed_events_in_page = 0;
|
||||
let events_in_page = event_pages.page.data.len();
|
||||
for event in &event_pages.page.data {
|
||||
let events_in_page = events.data.len();
|
||||
for event in events.data {
|
||||
if processed_event_ids.contains(&event.id.to_string()) {
|
||||
processed_events_in_page += 1;
|
||||
log::debug!("Stripe events: already processed '{}', skipping", event.id);
|
||||
log::debug!("Stripe event {} already processed: skipping", event.id);
|
||||
} else {
|
||||
unprocessed_events.push(event.clone());
|
||||
unprocessed_events.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -496,21 +409,15 @@ async fn poll_stripe_events(
|
||||
pages_of_already_processed_events += 1;
|
||||
}
|
||||
|
||||
if event_pages.page.has_more {
|
||||
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
|
||||
{
|
||||
log::info!("Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events");
|
||||
break;
|
||||
} else {
|
||||
log::info!("Stripe events: retrieving next page");
|
||||
event_pages = event_pages.next(&stripe_client).await?;
|
||||
}
|
||||
} else {
|
||||
if !events.has_more {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
|
||||
log::info!(
|
||||
"unprocessed events from Stripe: {}",
|
||||
unprocessed_events.len()
|
||||
);
|
||||
|
||||
// Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
|
||||
unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
|
||||
@@ -526,12 +433,12 @@ async fn poll_stripe_events(
|
||||
// If the event has happened too far in the past, we don't want to
|
||||
// process it and risk overwriting other more-recent updates.
|
||||
//
|
||||
// 1 day was chosen arbitrarily. This could be made longer or shorter.
|
||||
let one_day = Duration::from_secs(24 * 60 * 60);
|
||||
let a_day_ago = Utc::now() - one_day;
|
||||
if a_day_ago.timestamp() > event.created {
|
||||
// 1 hour was chosen arbitrarily. This could be made longer or shorter.
|
||||
let one_hour = Duration::from_secs(60 * 60);
|
||||
let an_hour_ago = Utc::now() - one_hour;
|
||||
if an_hour_ago.timestamp() > event.created {
|
||||
log::info!(
|
||||
"Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
|
||||
"Stripe event {} is more than {one_hour:?} old, marking as processed",
|
||||
event_id
|
||||
);
|
||||
app.db
|
||||
@@ -550,7 +457,7 @@ async fn poll_stripe_events(
|
||||
| EventType::CustomerSubscriptionPaused
|
||||
| EventType::CustomerSubscriptionResumed
|
||||
| EventType::CustomerSubscriptionDeleted => {
|
||||
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
|
||||
handle_customer_subscription_event(app, stripe_client, event).await
|
||||
}
|
||||
_ => Ok(()),
|
||||
};
|
||||
@@ -618,7 +525,6 @@ async fn handle_customer_event(
|
||||
|
||||
async fn handle_customer_subscription_event(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -664,52 +570,9 @@ async fn handle_customer_subscription_event(
|
||||
.await?;
|
||||
}
|
||||
|
||||
// When the user's subscription changes, we want to refresh their LLM tokens
|
||||
// to either grant/revoke access.
|
||||
rpc_server
|
||||
.refresh_llm_tokens_for_user(billing_customer.user_id)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetMonthlySpendParams {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GetMonthlySpendResponse {
|
||||
monthly_spend_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn get_monthly_spend(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetMonthlySpendParams>,
|
||||
) -> Result<Json<GetMonthlySpendResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(params.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
return Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"LLM database not available".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let monthly_spend = llm_db
|
||||
.get_user_spending_for_month(user.id, Utc::now())
|
||||
.await?
|
||||
.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
|
||||
|
||||
Ok(Json(GetMonthlySpendResponse {
|
||||
monthly_spend_in_cents: monthly_spend.0 as i32,
|
||||
}))
|
||||
}
|
||||
|
||||
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
fn from(value: SubscriptionStatus) -> Self {
|
||||
match value {
|
||||
@@ -772,15 +635,15 @@ async fn find_or_create_billing_customer(
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
|
||||
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
|
||||
log::warn!("failed to retrieve Stripe LLM usage price ID");
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -789,10 +652,15 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.context("failed to sync LLM usage to Stripe")
|
||||
.trace_err();
|
||||
sync_with_stripe(
|
||||
&app,
|
||||
&llm_db,
|
||||
&stripe_client,
|
||||
stripe_llm_usage_price_id.clone(),
|
||||
)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
@@ -801,44 +669,60 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
|
||||
async fn sync_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
llm_db: &LlmDatabase,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_llm_usage_price_id: Arc<str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let events = llm_db.get_billing_events().await?;
|
||||
let user_ids = events
|
||||
.iter()
|
||||
.map(|(event, _)| event.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
|
||||
let subscriptions = app.db.get_active_billing_subscriptions().await?;
|
||||
|
||||
for (event, model) in events {
|
||||
let Some((stripe_db_customer, stripe_db_subscription)) =
|
||||
stripe_subscriptions.get(&event.user_id)
|
||||
else {
|
||||
tracing::warn!(
|
||||
user_id = event.user_id.0,
|
||||
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
|
||||
.stripe_subscription_id
|
||||
.parse()
|
||||
.context("failed to parse stripe subscription id from db")?;
|
||||
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
|
||||
.stripe_customer_id
|
||||
.parse()
|
||||
.context("failed to parse stripe customer id from db")?;
|
||||
|
||||
let stripe_model = stripe_billing.register_model(&model).await?;
|
||||
stripe_billing
|
||||
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.await?;
|
||||
llm_db.consume_billing_event(event.id).await?;
|
||||
for (customer, subscription) in subscriptions {
|
||||
update_stripe_subscription(
|
||||
llm_db,
|
||||
stripe_client,
|
||||
&stripe_llm_usage_price_id,
|
||||
customer,
|
||||
subscription,
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_stripe_subscription(
|
||||
llm_db: &LlmDatabase,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_llm_usage_price_id: &Arc<str>,
|
||||
customer: billing_customer::Model,
|
||||
subscription: billing_subscription::Model,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let monthly_spending = llm_db
|
||||
.get_user_spending_for_month(customer.user_id, Utc::now())
|
||||
.await?;
|
||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
||||
.context("failed to parse subscription ID")?;
|
||||
|
||||
let monthly_spending_over_free_tier =
|
||||
monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS);
|
||||
|
||||
let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil();
|
||||
Subscription::update(
|
||||
stripe_client,
|
||||
&subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
||||
// TODO: Do we need to send up the `id` if a subscription item
|
||||
// with this price already exists, or will Stripe take care of
|
||||
// it?
|
||||
id: None,
|
||||
price: Some(stripe_llm_usage_price_id.to_string()),
|
||||
quantity: Some(new_quantity as u64),
|
||||
..Default::default()
|
||||
}]),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -670,6 +670,7 @@ pub struct EditorEventRow {
|
||||
time: i64,
|
||||
copilot_enabled: bool,
|
||||
copilot_enabled_for_language: bool,
|
||||
historical_event: bool,
|
||||
architecture: String,
|
||||
is_staff: Option<bool>,
|
||||
major: Option<i32>,
|
||||
@@ -717,6 +718,7 @@ impl EditorEventRow {
|
||||
country_code: country_code.unwrap_or("XX".to_string()),
|
||||
region_code: "".to_string(),
|
||||
city: "".to_string(),
|
||||
historical_event: false,
|
||||
is_via_ssh: event.is_via_ssh,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
/// A number of cents.
|
||||
#[derive(
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Hash,
|
||||
Clone,
|
||||
Copy,
|
||||
derive_more::Add,
|
||||
derive_more::AddAssign,
|
||||
derive_more::Sub,
|
||||
derive_more::SubAssign,
|
||||
)]
|
||||
pub struct Cents(pub u32);
|
||||
|
||||
impl Cents {
|
||||
pub const ZERO: Self = Self(0);
|
||||
|
||||
pub const fn new(cents: u32) -> Self {
|
||||
Self(cents)
|
||||
}
|
||||
|
||||
pub const fn from_dollars(dollars: u32) -> Self {
|
||||
Self(dollars * 100)
|
||||
}
|
||||
|
||||
pub fn saturating_sub(self, other: Cents) -> Self {
|
||||
Self(self.0.saturating_sub(other.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cents_new() {
|
||||
assert_eq!(Cents::new(50), Cents(50));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_from_dollars() {
|
||||
assert_eq!(Cents::from_dollars(1), Cents(100));
|
||||
assert_eq!(Cents::from_dollars(5), Cents(500));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_zero() {
|
||||
assert_eq!(Cents::ZERO, Cents(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_add() {
|
||||
assert_eq!(Cents(50) + Cents(30), Cents(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_add_assign() {
|
||||
let mut cents = Cents(50);
|
||||
cents += Cents(30);
|
||||
assert_eq!(cents, Cents(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_saturating_sub() {
|
||||
assert_eq!(Cents(50).saturating_sub(Cents(30)), Cents(20));
|
||||
assert_eq!(Cents(30).saturating_sub(Cents(50)), Cents(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_ordering() {
|
||||
assert!(Cents(50) > Cents(30));
|
||||
assert!(Cents(30) < Cents(50));
|
||||
assert_eq!(Cents(50), Cents(50));
|
||||
}
|
||||
}
|
||||
@@ -42,9 +42,6 @@ pub use tests::TestDb;
|
||||
|
||||
pub use ids::*;
|
||||
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
|
||||
pub use queries::billing_preferences::{
|
||||
CreateBillingPreferencesParams, UpdateBillingPreferencesParams,
|
||||
};
|
||||
pub use queries::billing_subscriptions::{
|
||||
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
|
||||
};
|
||||
|
||||
@@ -72,7 +72,6 @@ macro_rules! id_type {
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(BillingCustomerId);
|
||||
id_type!(BillingSubscriptionId);
|
||||
id_type!(BillingPreferencesId);
|
||||
id_type!(BufferId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(ChannelChatParticipantId);
|
||||
|
||||
@@ -2,7 +2,6 @@ use super::*;
|
||||
|
||||
pub mod access_tokens;
|
||||
pub mod billing_customers;
|
||||
pub mod billing_preferences;
|
||||
pub mod billing_subscriptions;
|
||||
pub mod buffers;
|
||||
pub mod channels;
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: ActiveValue<i32>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Returns the billing preferences for the given user, if they exist.
|
||||
pub async fn get_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_preference::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_preference::Entity::find()
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Creates new billing preferences for the given user.
|
||||
pub async fn create_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
params: &CreateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
params.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(preferences)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the billing preferences for the given user.
|
||||
pub async fn update_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
params: &UpdateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::update_many()
|
||||
.set(billing_preference::ActiveModel {
|
||||
max_monthly_llm_usage_spending_in_cents: params
|
||||
.max_monthly_llm_usage_spending_in_cents
|
||||
.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(preferences
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("billing preferences not found"))?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -114,31 +114,23 @@ impl Database {
|
||||
|
||||
pub async fn get_active_billing_subscriptions(
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let mut result = Vec::new();
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
result.push((customer, subscription));
|
||||
}
|
||||
Ok(subscriptions)
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -838,7 +838,6 @@ impl Database {
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id as u64,
|
||||
name: language_server.name,
|
||||
worktree_id: None,
|
||||
})
|
||||
.collect(),
|
||||
dev_server_project_id: project.dev_server_project_id,
|
||||
|
||||
@@ -718,7 +718,6 @@ impl Database {
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id as u64,
|
||||
name: language_server.name,
|
||||
worktree_id: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod access_token;
|
||||
pub mod billing_customer;
|
||||
pub mod billing_preference;
|
||||
pub mod billing_subscription;
|
||||
pub mod buffer;
|
||||
pub mod buffer_operation;
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
use crate::db::{BillingPreferencesId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_preferences")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingPreferencesId,
|
||||
pub created_at: DateTime,
|
||||
pub user_id: UserId,
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
mod cents;
|
||||
pub mod clickhouse;
|
||||
pub mod db;
|
||||
pub mod env;
|
||||
@@ -10,7 +9,6 @@ pub mod migrations;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
pub mod user_backfiller;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -22,17 +20,13 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
pub use cents::*;
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
pub use rate_limiter::*;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
@@ -180,6 +174,7 @@ pub struct Config {
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub stripe_llm_usage_price_id: Option<Arc<str>>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
@@ -199,7 +194,7 @@ impl Config {
|
||||
}
|
||||
|
||||
pub fn is_llm_billing_enabled(&self) -> bool {
|
||||
self.stripe_api_key.is_some()
|
||||
self.stripe_llm_usage_price_id.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -240,6 +235,7 @@ impl Config {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_llm_usage_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
}
|
||||
@@ -272,11 +268,9 @@ impl ServiceMode {
|
||||
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub clickhouse_client: Option<::clickhouse::Client>,
|
||||
@@ -290,20 +284,6 @@ impl AppState {
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||
.llm_database_url
|
||||
.clone()
|
||||
.zip(config.llm_database_max_connections)
|
||||
{
|
||||
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
|
||||
llm_db_options.max_connections(llm_database_max_connections);
|
||||
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
|
||||
llm_db.initialize().await?;
|
||||
Some(Arc::new(llm_db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let live_kit_client = if let Some(((server, key), secret)) = config
|
||||
.live_kit_server
|
||||
.as_ref()
|
||||
@@ -320,16 +300,11 @@ impl AppState {
|
||||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
live_kit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
clickhouse_client: config
|
||||
@@ -342,11 +317,12 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
let api_key = config
|
||||
.stripe_api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
|
||||
|
||||
Ok(stripe::Client::new(api_key))
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ mod telemetry;
|
||||
mod token;
|
||||
|
||||
use crate::{
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
|
||||
Config, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
@@ -20,14 +20,13 @@ use axum::{
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use db::TokenUsage;
|
||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use isahc_http_client::IsahcHttpClient;
|
||||
use rpc::ListModelsResponse;
|
||||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
@@ -44,7 +43,7 @@ pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub db: Arc<LlmDatabase>,
|
||||
pub http_client: ReqwestClient,
|
||||
pub http_client: IsahcHttpClient,
|
||||
pub clickhouse_client: Option<clickhouse::Client>,
|
||||
active_user_count_by_model:
|
||||
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
|
||||
@@ -70,8 +69,11 @@ impl LlmState {
|
||||
let db = Arc::new(db);
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client =
|
||||
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
|
||||
let http_client = IsahcHttpClient::builder()
|
||||
.default_header("User-Agent", user_agent)
|
||||
.build()
|
||||
.map(IsahcHttpClient::from)
|
||||
.context("failed to construct http client")?;
|
||||
|
||||
let this = Self {
|
||||
executor,
|
||||
@@ -416,7 +418,10 @@ async fn perform_completion(
|
||||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
tokens: TokenUsage::default(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
@@ -433,15 +438,13 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// The maximum monthly spending an individual user can reach on the free tier
|
||||
/// before they have to pay.
|
||||
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
||||
/// The maximum monthly spending an individual user can reach before they have to pay.
|
||||
pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100;
|
||||
|
||||
/// The default value to use for maximum spend per month if the user did not
|
||||
/// explicitly set a maximum spend.
|
||||
/// The maximum lifetime spending an individual user can reach before being cut off.
|
||||
///
|
||||
/// Used to prevent surprise bills.
|
||||
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
|
||||
/// Represented in cents.
|
||||
const LIFETIME_SPENDING_LIMIT_IN_CENTS: usize = 1_000 * 100;
|
||||
|
||||
async fn check_usage_limit(
|
||||
state: &Arc<LlmState>,
|
||||
@@ -461,31 +464,24 @@ async fn check_usage_limit(
|
||||
.await?;
|
||||
|
||||
if state.config.is_llm_billing_enabled() {
|
||||
if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT {
|
||||
if !claims.has_llm_subscription {
|
||||
if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS {
|
||||
if !claims.has_llm_subscription.unwrap_or(false) {
|
||||
return Err(Error::http(
|
||||
StatusCode::PAYMENT_REQUIRED,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if (usage.spending_this_month - FREE_TIER_MONTHLY_SPENDING_LIMIT)
|
||||
>= Cents(claims.max_monthly_spend_in_cents)
|
||||
{
|
||||
return Err(Error::Http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Remove this once we've rolled out monthly spending limits.
|
||||
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
|
||||
return Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let active_users = state.get_active_user_count(provider, model_name).await?;
|
||||
|
||||
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
|
||||
@@ -597,7 +593,10 @@ struct TokenCountingStream<S> {
|
||||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
tokens: TokenUsage,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
@@ -611,10 +610,10 @@ where
|
||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||
chunk.bytes.push(b'\n');
|
||||
self.tokens.input += chunk.input_tokens;
|
||||
self.tokens.output += chunk.output_tokens;
|
||||
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||
self.input_tokens += chunk.input_tokens;
|
||||
self.output_tokens += chunk.output_tokens;
|
||||
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
|
||||
self.cache_read_input_tokens += chunk.cache_read_input_tokens;
|
||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
@@ -627,11 +626,13 @@ where
|
||||
impl<S> Drop for TokenCountingStream<S> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let is_llm_billing_enabled = state.config.is_llm_billing_enabled();
|
||||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let tokens = self.tokens;
|
||||
let input_token_count = self.input_tokens;
|
||||
let output_token_count = self.output_tokens;
|
||||
let cache_creation_input_token_count = self.cache_creation_input_tokens;
|
||||
let cache_read_input_token_count = self.cache_read_input_tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
@@ -640,16 +641,10 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
tokens,
|
||||
// We're passing `false` here if LLM billing is not enabled
|
||||
// so that we don't write any records to the
|
||||
// `billing_events` table until we're ready to bill users.
|
||||
if is_llm_billing_enabled {
|
||||
claims.has_llm_subscription
|
||||
} else {
|
||||
false
|
||||
},
|
||||
Cents(claims.max_monthly_spend_in_cents),
|
||||
input_token_count,
|
||||
cache_creation_input_token_count,
|
||||
cache_read_input_token_count,
|
||||
output_token_count,
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
@@ -679,25 +674,24 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||
},
|
||||
model,
|
||||
provider: provider.to_string(),
|
||||
input_token_count: tokens.input as u64,
|
||||
cache_creation_input_token_count: tokens.input_cache_creation as u64,
|
||||
cache_read_input_token_count: tokens.input_cache_read as u64,
|
||||
output_token_count: tokens.output as u64,
|
||||
input_token_count: input_token_count as u64,
|
||||
cache_creation_input_token_count: cache_creation_input_token_count
|
||||
as u64,
|
||||
cache_read_input_token_count: cache_read_input_token_count as u64,
|
||||
output_token_count: output_token_count as u64,
|
||||
requests_this_minute: usage.requests_this_minute as u64,
|
||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||
tokens_this_day: usage.tokens_this_day as u64,
|
||||
input_tokens_this_month: usage.tokens_this_month.input as u64,
|
||||
input_tokens_this_month: usage.input_tokens_this_month as u64,
|
||||
cache_creation_input_tokens_this_month: usage
|
||||
.tokens_this_month
|
||||
.input_cache_creation
|
||||
.cache_creation_input_tokens_this_month
|
||||
as u64,
|
||||
cache_read_input_tokens_this_month: usage
|
||||
.tokens_this_month
|
||||
.input_cache_read
|
||||
.cache_read_input_tokens_this_month
|
||||
as u64,
|
||||
output_tokens_this_month: usage.tokens_this_month.output as u64,
|
||||
spending_this_month: usage.spending_this_month.0 as u64,
|
||||
lifetime_spending: usage.lifetime_spending.0 as u64,
|
||||
output_tokens_this_month: usage.output_tokens_this_month as u64,
|
||||
spending_this_month: usage.spending_this_month as u64,
|
||||
lifetime_spending: usage.lifetime_spending as u64,
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -20,7 +20,7 @@ use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
pub use queries::usages::{ActiveUserCount, TokenUsage};
|
||||
pub use queries::usages::ActiveUserCount;
|
||||
use sea_orm::prelude::*;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::{
|
||||
|
||||
@@ -3,9 +3,8 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::id_type;
|
||||
|
||||
id_type!(BillingEventId);
|
||||
id_type!(ModelId);
|
||||
id_type!(ProviderId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
id_type!(UsageId);
|
||||
id_type!(UsageMeasureId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod revoked_access_tokens;
|
||||
pub mod usages;
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use super::*;
|
||||
use crate::Result;
|
||||
use anyhow::Context as _;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let events_with_models = billing_event::Entity::find()
|
||||
.find_also_related(model::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
events_with_models
|
||||
.into_iter()
|
||||
.map(|(event, model)| {
|
||||
let model =
|
||||
model.context("could not find model associated with billing event")?;
|
||||
Ok((event, model))
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::llm::Cents;
|
||||
use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use crate::db::UserId;
|
||||
use chrono::{Datelike, Duration};
|
||||
use futures::StreamExt as _;
|
||||
use rpc::LanguageModelProvider;
|
||||
@@ -9,28 +8,17 @@ use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default)]
|
||||
pub struct TokenUsage {
|
||||
pub input: usize,
|
||||
pub input_cache_creation: usize,
|
||||
pub input_cache_read: usize,
|
||||
pub output: usize,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total(&self) -> usize {
|
||||
self.input + self.input_cache_creation + self.input_cache_read + self.output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub tokens_this_day: usize,
|
||||
pub tokens_this_month: TokenUsage,
|
||||
pub spending_this_month: Cents,
|
||||
pub lifetime_spending: Cents,
|
||||
pub input_tokens_this_month: usize,
|
||||
pub cache_creation_input_tokens_this_month: usize,
|
||||
pub cache_read_input_tokens_this_month: usize,
|
||||
pub output_tokens_this_month: usize,
|
||||
pub spending_this_month: usize,
|
||||
pub lifetime_spending: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
@@ -156,7 +144,7 @@ impl LlmDatabase {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Cents> {
|
||||
) -> Result<usize> {
|
||||
self.transaction(|tx| async move {
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
@@ -170,7 +158,7 @@ impl LlmDatabase {
|
||||
)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
let mut monthly_spending = Cents::ZERO;
|
||||
let mut monthly_spending_in_cents = 0;
|
||||
|
||||
while let Some(usage) = monthly_usages.next().await {
|
||||
let usage = usage?;
|
||||
@@ -178,7 +166,7 @@ impl LlmDatabase {
|
||||
continue;
|
||||
};
|
||||
|
||||
monthly_spending += calculate_spending(
|
||||
monthly_spending_in_cents += calculate_spending(
|
||||
model,
|
||||
usage.input_tokens as usize,
|
||||
usage.cache_creation_input_tokens as usize,
|
||||
@@ -187,7 +175,7 @@ impl LlmDatabase {
|
||||
);
|
||||
}
|
||||
|
||||
Ok(monthly_spending)
|
||||
Ok(monthly_spending_in_cents)
|
||||
})
|
||||
.await
|
||||
}
|
||||
@@ -250,7 +238,7 @@ impl LlmDatabase {
|
||||
monthly_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
0
|
||||
};
|
||||
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
|
||||
calculate_spending(
|
||||
@@ -261,27 +249,25 @@ impl LlmDatabase {
|
||||
lifetime_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
0
|
||||
};
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
input_cache_creation: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
input_cache_read: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
},
|
||||
input_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
cache_creation_input_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
cache_read_input_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
@@ -296,9 +282,10 @@ impl LlmDatabase {
|
||||
is_staff: bool,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
tokens: TokenUsage,
|
||||
has_llm_subscription: bool,
|
||||
max_monthly_spend: Cents,
|
||||
input_token_count: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
output_token_count: usize,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
@@ -325,6 +312,10 @@ impl LlmDatabase {
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let total_token_count = input_token_count
|
||||
+ cache_read_input_tokens
|
||||
+ cache_creation_input_tokens
|
||||
+ output_token_count;
|
||||
let tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
@@ -333,7 +324,7 @@ impl LlmDatabase {
|
||||
&usages,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
now,
|
||||
tokens.total(),
|
||||
total_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
@@ -345,7 +336,7 @@ impl LlmDatabase {
|
||||
&usages,
|
||||
UsageMeasure::TokensPerDay,
|
||||
now,
|
||||
tokens.total(),
|
||||
total_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
@@ -369,14 +360,18 @@ impl LlmDatabase {
|
||||
Some(usage) => {
|
||||
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
input_tokens: ActiveValue::set(
|
||||
usage.input_tokens + input_token_count as i64,
|
||||
),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(
|
||||
usage.output_tokens + output_token_count as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
@@ -388,12 +383,12 @@ impl LlmDatabase {
|
||||
model_id: ActiveValue::set(model.id),
|
||||
month: ActiveValue::set(month),
|
||||
year: ActiveValue::set(year),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
cache_creation_input_tokens as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
|
||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
@@ -409,27 +404,6 @@ impl LlmDatabase {
|
||||
monthly_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
if !is_staff
|
||||
&& spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
|
||||
&& has_llm_subscription
|
||||
&& (spending_this_month - FREE_TIER_MONTHLY_SPENDING_LIMIT) <= max_monthly_spend
|
||||
{
|
||||
billing_event::ActiveModel {
|
||||
id: ActiveValue::not_set(),
|
||||
idempotency_key: ActiveValue::not_set(),
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_cache_creation_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Update lifetime usage
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
@@ -444,14 +418,18 @@ impl LlmDatabase {
|
||||
Some(usage) => {
|
||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
input_tokens: ActiveValue::set(
|
||||
usage.input_tokens + input_token_count as i64,
|
||||
),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(
|
||||
usage.output_tokens + output_token_count as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
@@ -461,12 +439,12 @@ impl LlmDatabase {
|
||||
lifetime_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
cache_creation_input_tokens as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
|
||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
@@ -486,12 +464,11 @@ impl LlmDatabase {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage.input_tokens as usize,
|
||||
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
|
||||
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
|
||||
output: monthly_usage.output_tokens as usize,
|
||||
},
|
||||
input_tokens_this_month: monthly_usage.input_tokens as usize,
|
||||
cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
|
||||
as usize,
|
||||
cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
|
||||
output_tokens_this_month: monthly_usage.output_tokens as usize,
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
@@ -660,7 +637,7 @@ fn calculate_spending(
|
||||
cache_creation_input_tokens_this_month: usize,
|
||||
cache_read_input_tokens_this_month: usize,
|
||||
output_tokens_this_month: usize,
|
||||
) -> Cents {
|
||||
) -> usize {
|
||||
let input_token_cost =
|
||||
input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000;
|
||||
let cache_creation_input_token_cost = cache_creation_input_tokens_this_month
|
||||
@@ -671,11 +648,10 @@ fn calculate_spending(
|
||||
/ 1_000_000;
|
||||
let output_token_cost =
|
||||
output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000;
|
||||
let spending = input_token_cost
|
||||
input_token_cost
|
||||
+ cache_creation_input_token_cost
|
||||
+ cache_read_input_token_cost
|
||||
+ output_token_cost;
|
||||
Cents::new(spending as u32)
|
||||
+ output_token_cost
|
||||
}
|
||||
|
||||
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod billing_event;
|
||||
pub mod lifetime_usage;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{BillingEventId, ModelId},
|
||||
};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingEventId,
|
||||
pub idempotency_key: Uuid,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub input_cache_creation_tokens: i64,
|
||||
pub input_cache_read_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -29,8 +29,6 @@ pub enum Relation {
|
||||
Provider,
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
#[sea_orm(has_many = "super::billing_event::Entity")]
|
||||
BillingEvents,
|
||||
}
|
||||
|
||||
impl Related<super::provider::Entity> for Entity {
|
||||
@@ -45,10 +43,4 @@ impl Related<super::usage::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_event::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingEvents.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
mod billing_tests;
|
||||
mod provider_tests;
|
||||
mod usage_tests;
|
||||
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::{
|
||||
db::{queries::providers::ModelParams, LlmDatabase, TokenUsage},
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
},
|
||||
test_llm_db, Cents,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(
|
||||
test_billing_limit_exceeded,
|
||||
test_billing_limit_exceeded_postgres
|
||||
);
|
||||
|
||||
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "fake-claude-limerick";
|
||||
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
|
||||
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
|
||||
|
||||
// Initialize the database and insert the model
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
|
||||
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Set a fixed datetime for consistent testing
|
||||
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let max_monthly_spend = Cents::from_dollars(11);
|
||||
|
||||
// Record usage that brings us close to the limit but doesn't exceed it
|
||||
// Let's say we use $10.50 worth of tokens
|
||||
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
|
||||
let usage = TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// Verify that before we record any usage, there are 0 billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 0);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the recorded usage and spending
|
||||
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
// Verify that we exceeded the free tier usage
|
||||
assert_eq!(recorded_usage.spending_this_month, Cents::new(1050));
|
||||
assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT);
|
||||
|
||||
// Verify that there is one `billing_event` record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
let (billing_event, _model) = &billing_events[0];
|
||||
assert_eq!(billing_event.user_id, user_id);
|
||||
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
|
||||
assert_eq!(billing_event.input_cache_creation_tokens, 0);
|
||||
assert_eq!(billing_event.input_cache_read_tokens, 0);
|
||||
assert_eq!(billing_event.output_tokens, 0);
|
||||
|
||||
// Record usage that puts us at $20.50
|
||||
let usage_2 = TokenUsage {
|
||||
input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_2,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(updated_usage.spending_this_month, Cents::new(2050));
|
||||
|
||||
// Verify that there are now two billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 2);
|
||||
|
||||
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit
|
||||
let usage_exceeding = TokenUsage {
|
||||
input: tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// This should still create a billing event as it's the first request that exceeds the limit
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_exceeding,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(updated_usage.spending_this_month, Cents::new(2150));
|
||||
|
||||
// Verify that we never exceed the user max spending for the user
|
||||
// and avoid charging them.
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 2);
|
||||
}
|
||||
@@ -2,9 +2,9 @@ use crate::{
|
||||
db::UserId,
|
||||
llm::db::{
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
LlmDatabase, TokenUsage,
|
||||
LlmDatabase,
|
||||
},
|
||||
test_llm_db, Cents,
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -36,42 +36,14 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let now = t0;
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let now = t0 + Duration::seconds(10);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -80,14 +52,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 3000,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 3000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -99,35 +69,19 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 2000,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 3000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -136,14 +90,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 5000,
|
||||
tokens_this_day: 6000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 6000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -156,34 +108,18 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
tokens_this_day: 5000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 6000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 4000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -192,14 +128,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 4000,
|
||||
tokens_this_day: 9000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 10000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 10000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -209,23 +143,9 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
.with_timezone(&Utc);
|
||||
|
||||
// Test cache creation input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -234,35 +154,19 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 1500,
|
||||
tokens_this_day: 1500,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 1000,
|
||||
cache_creation_input_tokens_this_month: 500,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
// Test cache read input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -271,14 +175,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 2800,
|
||||
tokens_this_day: 2800,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 2000,
|
||||
cache_creation_input_tokens_this_month: 500,
|
||||
cache_read_input_tokens_this_month: 300,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::{
|
||||
db::{billing_preference, UserId},
|
||||
Config,
|
||||
};
|
||||
use crate::{db::UserId, Config};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
@@ -20,20 +16,22 @@ pub struct LlmTokenClaims {
|
||||
pub github_user_login: String,
|
||||
pub is_staff: bool,
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
pub has_llm_subscription: bool,
|
||||
pub max_monthly_spend_in_cents: u32,
|
||||
// This field is temporarily optional so it can be added
|
||||
// in a backwards-compatible way. We can make it required
|
||||
// once all of the LLM tokens have cycled (~1 hour after
|
||||
// this change has been deployed).
|
||||
#[serde(default)]
|
||||
pub has_llm_subscription: Option<bool>,
|
||||
pub plan: rpc::proto::Plan,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn create(
|
||||
user_id: UserId,
|
||||
github_user_login: String,
|
||||
is_staff: bool,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
has_llm_closed_beta_feature_flag: bool,
|
||||
has_llm_subscription: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
@@ -53,11 +51,7 @@ impl LlmTokenClaims {
|
||||
github_user_login,
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
has_llm_subscription,
|
||||
max_monthly_spend_in_cents: billing_preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents as u32
|
||||
}),
|
||||
has_llm_subscription: Some(has_llm_subscription),
|
||||
plan,
|
||||
};
|
||||
|
||||
|
||||
@@ -111,13 +111,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
if let Some(stripe_billing) = state.stripe_billing.clone() {
|
||||
let executor = state.executor.clone();
|
||||
executor.spawn_detached(async move {
|
||||
stripe_billing.initialize().await.trace_err();
|
||||
});
|
||||
}
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
RateLimiter::save_periodically(
|
||||
@@ -132,8 +125,6 @@ async fn main() -> Result<()> {
|
||||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||
rpc_server.start().await?;
|
||||
|
||||
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::api::routes(rpc_server.clone()))
|
||||
.merge(collab::rpc::routes(rpc_server.clone()));
|
||||
@@ -142,6 +133,7 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
if mode.is_api() {
|
||||
poll_stripe_events_periodically(state.clone());
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
spawn_user_backfiller(state.clone());
|
||||
|
||||
@@ -163,9 +155,8 @@ async fn main() -> Result<()> {
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
if let Some(mut llm_db) = llm_db {
|
||||
llm_db.initialize().await?;
|
||||
sync_llm_usage_with_stripe_periodically(state.clone());
|
||||
if let Some(llm_db) = llm_db {
|
||||
sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
|
||||
}
|
||||
|
||||
app = app
|
||||
|
||||
@@ -36,8 +36,8 @@ use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
use http_client::HttpClient;
|
||||
use isahc_http_client::IsahcHttpClient;
|
||||
use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use sha2::Digest;
|
||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||
|
||||
@@ -469,6 +469,9 @@ impl Server {
|
||||
.add_request_handler(user_handler(
|
||||
forward_project_request_for_owner::<proto::TaskContextForLocation>,
|
||||
))
|
||||
.add_request_handler(user_handler(
|
||||
forward_project_request_for_owner::<proto::TaskTemplates>,
|
||||
))
|
||||
.add_request_handler(user_handler(
|
||||
forward_read_only_project_request::<proto::GetHover>,
|
||||
))
|
||||
@@ -961,8 +964,8 @@ impl Server {
|
||||
tracing::info!("connection opened");
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client = match ReqwestClient::user_agent(&user_agent) {
|
||||
Ok(http_client) => Arc::new(http_client),
|
||||
let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
|
||||
Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to create HTTP client");
|
||||
return;
|
||||
@@ -1218,15 +1221,6 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
self.peer
|
||||
.send(connection_id, proto::RefreshLlmToken {})
|
||||
.trace_err();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
||||
ServerSnapshot {
|
||||
connection_pool: ConnectionPoolGuard {
|
||||
@@ -4926,14 +4920,10 @@ async fn get_llm_api_token(
|
||||
if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
|
||||
Err(anyhow!("account too young"))?
|
||||
}
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
user.id,
|
||||
user.github_login.clone(),
|
||||
session.is_staff(),
|
||||
billing_preferences,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
session.has_llm_subscription(&db).await?,
|
||||
session.current_plan(&db).await?,
|
||||
|
||||
@@ -1,479 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{llm, Cents, Result};
|
||||
use anyhow::Context;
|
||||
use chrono::{Datelike, Utc};
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
client: Arc<stripe::Client>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct StripeBillingState {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||
}
|
||||
|
||||
pub struct StripeModel {
|
||||
input_tokens_price: StripeBillingPrice,
|
||||
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||
input_cache_read_tokens_price: StripeBillingPrice,
|
||||
output_tokens_price: StripeBillingPrice,
|
||||
}
|
||||
|
||||
struct StripeBillingPrice {
|
||||
id: stripe::PriceId,
|
||||
meter_event_name: String,
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
state: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
log::info!("StripeBilling: initializing");
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
let (meters, prices) = futures::try_join!(
|
||||
StripeMeter::list(&self.client),
|
||||
stripe::Price::list(
|
||||
&self.client,
|
||||
&stripe::ListPrices {
|
||||
limit: Some(100),
|
||||
..Default::default()
|
||||
}
|
||||
)
|
||||
)?;
|
||||
|
||||
for meter in meters.data {
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter.event_name.clone(), meter);
|
||||
}
|
||||
|
||||
for price in prices.data {
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
state.price_ids_by_meter_id.insert(meter, price.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("StripeBilling: initialized");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||
let input_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_tokens", model.id),
|
||||
&format!("{} (Input Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_creation_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_cache_creation_tokens", model.id),
|
||||
&format!("{} (Input Cache Creation Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_read_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_cache_read_tokens", model.id),
|
||||
&format!("{} (Input Cache Read Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/output_tokens", model.id),
|
||||
&format!("{} (Output Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_output_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
Ok(StripeModel {
|
||||
input_tokens_price,
|
||||
input_cache_creation_tokens_price,
|
||||
input_cache_read_tokens_price,
|
||||
output_tokens_price,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_or_insert_price(
|
||||
&self,
|
||||
meter_event_name: &str,
|
||||
price_description: &str,
|
||||
price_per_million_tokens: Cents,
|
||||
) -> Result<StripeBillingPrice> {
|
||||
// Fast code path when the meter and the price already exist.
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
return Ok(StripeBillingPrice {
|
||||
id: price_id.clone(),
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
meter.clone()
|
||||
} else {
|
||||
let meter = StripeMeter::create(
|
||||
&self.client,
|
||||
StripeCreateMeterParams {
|
||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||
display_name: price_description.to_string(),
|
||||
event_name: meter_event_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter_event_name.to_string(), meter.clone());
|
||||
meter
|
||||
};
|
||||
|
||||
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
price_id.clone()
|
||||
} else {
|
||||
let price = stripe::Price::create(
|
||||
&self.client,
|
||||
stripe::CreatePrice {
|
||||
active: Some(true),
|
||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||
currency: stripe::Currency::USD,
|
||||
currency_options: None,
|
||||
custom_unit_amount: None,
|
||||
expand: &[],
|
||||
lookup_key: None,
|
||||
metadata: None,
|
||||
nickname: None,
|
||||
product: None,
|
||||
product_data: Some(stripe::CreatePriceProductData {
|
||||
id: None,
|
||||
active: Some(true),
|
||||
metadata: None,
|
||||
name: price_description.to_string(),
|
||||
statement_descriptor: None,
|
||||
tax_code: None,
|
||||
unit_label: None,
|
||||
}),
|
||||
recurring: Some(stripe::CreatePriceRecurring {
|
||||
aggregate_usage: None,
|
||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||
interval_count: None,
|
||||
trial_period_days: None,
|
||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||
meter: Some(meter.id.clone()),
|
||||
}),
|
||||
tax_behavior: None,
|
||||
tiers: None,
|
||||
tiers_mode: None,
|
||||
transfer_lookup_key: None,
|
||||
transform_quantity: None,
|
||||
unit_amount: None,
|
||||
unit_amount_decimal: Some(&format!(
|
||||
"{:.12}",
|
||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.price_ids_by_meter_id
|
||||
.insert(meter.id, price.id.clone());
|
||||
price.id
|
||||
};
|
||||
|
||||
Ok(StripeBillingPrice {
|
||||
id: price_id,
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_model(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
model: &StripeModel,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
|
||||
{
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_read_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.output_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !items.is_empty() {
|
||||
items.extend(subscription.items.data.iter().map(|item| {
|
||||
stripe::UpdateSubscriptionItems {
|
||||
id: Some(item.id.to_string()),
|
||||
..Default::default()
|
||||
}
|
||||
}));
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(items),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
model: &StripeModel,
|
||||
event: &llm::db::billing_event::Model,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
|
||||
if event.input_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_creation_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_creation_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_read_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_read_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_read_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.output_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("output_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.output_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.output_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn checkout(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
model: &StripeModel,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let first_of_next_month = Utc::now()
|
||||
.checked_add_months(chrono::Months::new(1))
|
||||
.unwrap()
|
||||
.with_day(1)
|
||||
.unwrap();
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
|
||||
billing_cycle_anchor: Some(first_of_next_month.timestamp()),
|
||||
..Default::default()
|
||||
});
|
||||
params.line_items = Some(
|
||||
[
|
||||
&model.input_tokens_price.id,
|
||||
&model.input_cache_creation_tokens_price.id,
|
||||
&model.input_cache_read_tokens_price.id,
|
||||
&model.output_tokens_price.id,
|
||||
]
|
||||
.into_iter()
|
||||
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(price_id.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DefaultAggregation {
|
||||
formula: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterParams<'a> {
|
||||
default_aggregation: DefaultAggregation,
|
||||
display_name: String,
|
||||
event_name: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
struct StripeMeter {
|
||||
id: String,
|
||||
event_name: String,
|
||||
}
|
||||
|
||||
impl StripeMeter {
|
||||
pub fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterParams,
|
||||
) -> stripe::Response<Self> {
|
||||
client.post_form("/billing/meters", params)
|
||||
}
|
||||
|
||||
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
limit: Option<u64>,
|
||||
}
|
||||
|
||||
client.get_query("/billing/meters", Params { limit: Some(100) })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StripeMeterEvent {
|
||||
identifier: String,
|
||||
}
|
||||
|
||||
impl StripeMeterEvent {
|
||||
pub async fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterEventParams<'_>,
|
||||
) -> Result<Self, stripe::StripeError> {
|
||||
let identifier = params.identifier;
|
||||
match client.post_form("/billing/meter_events", params).await {
|
||||
Ok(event) => Ok(event),
|
||||
Err(stripe::StripeError::Stripe(error)) => {
|
||||
if error.http_status == 400
|
||||
&& error
|
||||
.message
|
||||
.as_ref()
|
||||
.map_or(false, |message| message.contains(identifier))
|
||||
{
|
||||
Ok(Self {
|
||||
identifier: identifier.to_string(),
|
||||
})
|
||||
} else {
|
||||
Err(stripe::StripeError::Stripe(error))
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterEventParams<'a> {
|
||||
identifier: &'a str,
|
||||
event_name: &'a str,
|
||||
payload: StripeCreateMeterEventPayload<'a>,
|
||||
timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterEventPayload<'a> {
|
||||
value: u64,
|
||||
stripe_customer_id: &'a stripe::CustomerId,
|
||||
}
|
||||
|
||||
fn subscription_contains_price(
|
||||
subscription: &stripe::Subscription,
|
||||
price_id: &stripe::PriceId,
|
||||
) -> bool {
|
||||
subscription.items.data.iter().any(|item| {
|
||||
item.price
|
||||
.as_ref()
|
||||
.map_or(false, |price| price.id == *price_id)
|
||||
})
|
||||
}
|
||||
@@ -50,7 +50,7 @@ async fn test_channel_guests(
|
||||
project_b.read_with(cx_b, |project, _| project.remote_id()),
|
||||
Some(project_id),
|
||||
);
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(project_b
|
||||
.update(cx_b, |project, cx| {
|
||||
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
|
||||
@@ -103,7 +103,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
workspace.active_item_as::<Editor>(cx).unwrap(),
|
||||
)
|
||||
});
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(editor_b.update(cx_b, |e, cx| e.read_only(cx)));
|
||||
assert!(room_b.read_with(cx_b, |room, _| !room.can_use_microphone()));
|
||||
assert!(room_b
|
||||
@@ -127,7 +127,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
cx_a.run_until_parked();
|
||||
|
||||
// project and buffers are now editable
|
||||
assert!(project_b.read_with(cx_b, |project, cx| !project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| !project.is_read_only()));
|
||||
assert!(editor_b.update(cx_b, |editor, cx| !editor.read_only(cx)));
|
||||
|
||||
// B sees themselves as muted, and can unmute.
|
||||
@@ -153,7 +153,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
cx_a.run_until_parked();
|
||||
|
||||
// project and buffers are no longer editable
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(editor_b.update(cx_b, |editor, cx| editor.read_only(cx)));
|
||||
assert!(room_b
|
||||
.update(cx_b, |room, cx| room.share_microphone(cx))
|
||||
|
||||
@@ -262,7 +262,7 @@ async fn test_dev_server_leave_room(
|
||||
cx1.executor().run_until_parked();
|
||||
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -308,7 +308,7 @@ async fn test_dev_server_delete(
|
||||
cx1.executor().run_until_parked();
|
||||
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
|
||||
cx1.update(|cx| {
|
||||
dev_server_projects::Store::global(cx).update(cx, |store, _| {
|
||||
@@ -418,12 +418,12 @@ async fn test_dev_server_refresh_access_token(
|
||||
|
||||
// Assert that the other client was disconnected
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
|
||||
// Assert that the owner of the dev server does not see the dev server as online anymore
|
||||
let (workspace, cx1) = client1.active_workspace(cx1);
|
||||
cx1.update(|cx| {
|
||||
assert!(workspace.read(cx).project().read(cx).is_disconnected(cx));
|
||||
assert!(workspace.read(cx).project().read(cx).is_disconnected());
|
||||
dev_server_projects::Store::global(cx).update(cx, |store, _| {
|
||||
assert_eq!(
|
||||
store.dev_servers().first().unwrap().status,
|
||||
|
||||
@@ -114,7 +114,7 @@ async fn test_host_disconnect(
|
||||
|
||||
project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
|
||||
|
||||
project_b.read_with(cx_b, |project, cx| project.is_read_only(cx));
|
||||
project_b.read_with(cx_b, |project, _| project.is_read_only());
|
||||
|
||||
assert!(worktree_a.read_with(cx_a, |tree, _| !tree.has_update_observer()));
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ use language::{
|
||||
use live_kit_client::MacOSDisplay;
|
||||
use lsp::LanguageServerId;
|
||||
use parking_lot::Mutex;
|
||||
use project::lsp_store::FormatTarget;
|
||||
use project::{
|
||||
lsp_store::FormatTrigger, search::SearchQuery, search::SearchResult, DiagnosticSummary,
|
||||
HoverBlockKind, Project, ProjectPath,
|
||||
@@ -1390,7 +1389,7 @@ async fn test_unshare_project(
|
||||
.unwrap();
|
||||
executor.run_until_parked();
|
||||
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_disconnected(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_disconnected()));
|
||||
|
||||
// Client C opens the project.
|
||||
let project_c = client_c.join_remote_project(project_id, cx_c).await;
|
||||
@@ -1403,7 +1402,7 @@ async fn test_unshare_project(
|
||||
|
||||
assert!(worktree_a.read_with(cx_a, |tree, _| !tree.has_update_observer()));
|
||||
|
||||
assert!(project_c.read_with(cx_c, |project, cx| project.is_disconnected(cx)));
|
||||
assert!(project_c.read_with(cx_c, |project, _| project.is_disconnected()));
|
||||
|
||||
// Client C can open the project again after client A re-shares.
|
||||
let project_id = active_call_a
|
||||
@@ -1428,8 +1427,8 @@ async fn test_unshare_project(
|
||||
|
||||
project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
|
||||
|
||||
project_c2.read_with(cx_c, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_c2.read_with(cx_c, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
assert!(project.collaborators().is_empty());
|
||||
});
|
||||
}
|
||||
@@ -1561,8 +1560,8 @@ async fn test_project_reconnect(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
project_b1.read_with(cx_b, |project, _| {
|
||||
assert!(!project.is_disconnected());
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
@@ -1662,7 +1661,7 @@ async fn test_project_reconnect(
|
||||
});
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert!(!project.is_disconnected());
|
||||
assert_eq!(
|
||||
project
|
||||
.worktree_for_id(worktree1_id, cx)
|
||||
@@ -1696,9 +1695,9 @@ async fn test_project_reconnect(
|
||||
);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, cx| assert!(project.is_disconnected(cx)));
|
||||
project_b2.read_with(cx_b, |project, _| assert!(project.is_disconnected()));
|
||||
|
||||
project_b3.read_with(cx_b, |project, cx| assert!(!project.is_disconnected(cx)));
|
||||
project_b3.read_with(cx_b, |project, _| assert!(!project.is_disconnected()));
|
||||
|
||||
buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WaZ"));
|
||||
|
||||
@@ -1755,7 +1754,7 @@ async fn test_project_reconnect(
|
||||
executor.run_until_parked();
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert!(!project.is_disconnected());
|
||||
assert_eq!(
|
||||
project
|
||||
.worktree_for_id(worktree1_id, cx)
|
||||
@@ -1789,7 +1788,7 @@ async fn test_project_reconnect(
|
||||
);
|
||||
});
|
||||
|
||||
project_b3.read_with(cx_b, |project, cx| assert!(project.is_disconnected(cx)));
|
||||
project_b3.read_with(cx_b, |project, _| assert!(project.is_disconnected()));
|
||||
|
||||
buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WXaYZ"));
|
||||
|
||||
@@ -3817,8 +3816,8 @@ async fn test_leaving_project(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_b2.read_with(cx_b, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
});
|
||||
|
||||
project_c.read_with(cx_c, |project, _| {
|
||||
@@ -3850,12 +3849,12 @@ async fn test_leaving_project(
|
||||
assert_eq!(project.collaborators().len(), 0);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_b2.read_with(cx_b, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
});
|
||||
|
||||
project_c.read_with(cx_c, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_c.read_with(cx_c, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4418,7 +4417,6 @@ async fn test_formatting_buffer(
|
||||
HashSet::from_iter([buffer_b.clone()]),
|
||||
true,
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -4452,7 +4450,6 @@ async fn test_formatting_buffer(
|
||||
HashSet::from_iter([buffer_b.clone()]),
|
||||
true,
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -4558,7 +4555,6 @@ async fn test_prettier_formatting_buffer(
|
||||
HashSet::from_iter([buffer_b.clone()]),
|
||||
true,
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -4578,7 +4574,6 @@ async fn test_prettier_formatting_buffer(
|
||||
HashSet::from_iter([buffer_a.clone()]),
|
||||
true,
|
||||
FormatTrigger::Manual,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
||||
@@ -1168,7 +1168,7 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
Some((project, cx))
|
||||
});
|
||||
|
||||
if !guest_project.is_disconnected(cx) {
|
||||
if !guest_project.is_disconnected() {
|
||||
if let Some((host_project, host_cx)) = host_project {
|
||||
let host_worktree_snapshots =
|
||||
host_project.read_with(host_cx, |host_project, cx| {
|
||||
@@ -1254,8 +1254,8 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
|
||||
let buffers = client.buffers().clone();
|
||||
for (guest_project, guest_buffers) in &buffers {
|
||||
let project_id = if guest_project.read_with(client_cx, |project, cx| {
|
||||
project.is_local() || project.is_disconnected(cx)
|
||||
let project_id = if guest_project.read_with(client_cx, |project, _| {
|
||||
project.is_local() || project.is_disconnected()
|
||||
}) {
|
||||
continue;
|
||||
} else {
|
||||
|
||||
@@ -532,9 +532,9 @@ impl<T: RandomizedTest> TestPlan<T> {
|
||||
server.allow_connections();
|
||||
|
||||
for project in client.dev_server_projects().iter() {
|
||||
project.read_with(&client_cx, |project, cx| {
|
||||
project.read_with(&client_cx, |project, _| {
|
||||
assert!(
|
||||
project.is_disconnected(cx),
|
||||
project.is_disconnected(),
|
||||
"project {:?} should be read only",
|
||||
project.remote_id()
|
||||
)
|
||||
|
||||
@@ -2,12 +2,10 @@ use crate::tests::TestServer;
|
||||
use call::ActiveCall;
|
||||
use fs::{FakeFs, Fs as _};
|
||||
use gpui::{Context as _, TestAppContext};
|
||||
use http_client::BlockedHttpClient;
|
||||
use language::{language_settings::all_language_settings, LanguageRegistry};
|
||||
use node_runtime::NodeRuntime;
|
||||
use language::language_settings::all_language_settings;
|
||||
use project::ProjectPath;
|
||||
use remote::SshRemoteClient;
|
||||
use remote_server::{HeadlessAppState, HeadlessProject};
|
||||
use remote_server::HeadlessProject;
|
||||
use serde_json::json;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
@@ -50,22 +48,8 @@ async fn test_sharing_an_ssh_remote_project(
|
||||
|
||||
// User A connects to the remote project via SSH.
|
||||
server_cx.update(HeadlessProject::init);
|
||||
let remote_http_client = Arc::new(BlockedHttpClient);
|
||||
let node = NodeRuntime::unavailable();
|
||||
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
|
||||
let _headless_project = server_cx.new_model(|cx| {
|
||||
client::init_settings(cx);
|
||||
HeadlessProject::new(
|
||||
HeadlessAppState {
|
||||
session: server_ssh,
|
||||
fs: remote_fs.clone(),
|
||||
http_client: remote_http_client,
|
||||
node_runtime: node,
|
||||
languages,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let _headless_project =
|
||||
server_cx.new_model(|cx| HeadlessProject::new(server_ssh, remote_fs.clone(), cx));
|
||||
|
||||
let (project_a, worktree_id) = client_a
|
||||
.build_ssh_project("/code/project1", client_ssh, cx_a)
|
||||
|
||||
@@ -635,11 +635,9 @@ impl TestServer {
|
||||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
llm_db: None,
|
||||
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
stripe_client: None,
|
||||
stripe_billing: None,
|
||||
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
|
||||
executor,
|
||||
clickhouse_client: None,
|
||||
@@ -679,6 +677,7 @@ impl TestServer {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_llm_usage_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
},
|
||||
|
||||
@@ -111,7 +111,7 @@ impl MessageEditor {
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.set_show_wrap_guides(false, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_completion_provider(Some(Box::new(MessageEditorCompletionProvider(this))));
|
||||
editor.set_completion_provider(Box::new(MessageEditorCompletionProvider(this)));
|
||||
editor.set_auto_replace_emoji_shortcode(
|
||||
MessageEditorSettings::get_global(cx)
|
||||
.auto_replace_emoji_shortcode
|
||||
|
||||
@@ -26,7 +26,7 @@ const JSON_RPC_VERSION: &str = "2.0";
|
||||
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncAppContext)>;
|
||||
type NotificationHandler = Box<dyn Send + FnMut(RequestId, Value, AsyncAppContext)>;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
@@ -94,6 +94,7 @@ enum CspResult<T> {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Notification<'a, T> {
|
||||
jsonrpc: &'static str,
|
||||
id: RequestId,
|
||||
#[serde(borrow)]
|
||||
method: &'a str,
|
||||
params: T,
|
||||
@@ -102,6 +103,7 @@ struct Notification<'a, T> {
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct AnyNotification<'a> {
|
||||
jsonrpc: &'a str,
|
||||
id: RequestId,
|
||||
method: String,
|
||||
#[serde(default)]
|
||||
params: Option<Value>,
|
||||
@@ -244,7 +246,11 @@ impl Client {
|
||||
if let Some(handler) =
|
||||
notification_handlers.get_mut(notification.method.as_str())
|
||||
{
|
||||
handler(notification.params.unwrap_or(Value::Null), cx.clone());
|
||||
handler(
|
||||
notification.id,
|
||||
notification.params.unwrap_or(Value::Null),
|
||||
cx.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -372,8 +378,10 @@ impl Client {
|
||||
/// Sends a notification to the context server without expecting a response.
|
||||
/// This function serializes the notification and sends it through the outbound channel.
|
||||
pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
|
||||
let id = self.next_id.fetch_add(1, SeqCst);
|
||||
let notification = serde_json::to_string(&Notification {
|
||||
jsonrpc: JSON_RPC_VERSION,
|
||||
id: RequestId::Int(id),
|
||||
method,
|
||||
params,
|
||||
})
|
||||
@@ -382,13 +390,13 @@ impl Client {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
pub fn on_notification<F>(&self, method: &'static str, mut f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncAppContext),
|
||||
{
|
||||
self.notification_handlers
|
||||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
.insert(method, Box::new(move |_, params, cx| f(params, cx)));
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
|
||||
@@ -85,7 +85,7 @@ impl ContextServer {
|
||||
)?;
|
||||
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
let client_info = types::EntityInfo {
|
||||
name: "Zed".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
|
||||
@@ -11,6 +11,8 @@ use collections::HashMap;
|
||||
use crate::client::Client;
|
||||
use crate::types;
|
||||
|
||||
pub use types::PromptInfo;
|
||||
|
||||
const PROTOCOL_VERSION: u32 = 1;
|
||||
|
||||
pub struct ModelContextProtocol {
|
||||
@@ -24,7 +26,7 @@ impl ModelContextProtocol {
|
||||
|
||||
pub async fn initialize(
|
||||
self,
|
||||
client_info: types::Implementation,
|
||||
client_info: types::EntityInfo,
|
||||
) -> Result<InitializedContextServerProtocol> {
|
||||
let params = types::InitializeParams {
|
||||
protocol_version: PROTOCOL_VERSION,
|
||||
@@ -94,7 +96,7 @@ impl InitializedContextServerProtocol {
|
||||
}
|
||||
|
||||
/// List the MCP prompts.
|
||||
pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
|
||||
pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
|
||||
self.check_capability(ServerCapability::Prompts)?;
|
||||
|
||||
let response: types::PromptsListResponse = self
|
||||
@@ -105,18 +107,6 @@ impl InitializedContextServerProtocol {
|
||||
Ok(response.prompts)
|
||||
}
|
||||
|
||||
/// List the MCP resources.
|
||||
pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
|
||||
self.check_capability(ServerCapability::Resources)?;
|
||||
|
||||
let response: types::ResourcesListResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::ResourcesList.as_str(), ())
|
||||
.await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Executes a prompt with the given arguments and returns the result.
|
||||
pub async fn run_prompt<P: AsRef<str>>(
|
||||
&self,
|
||||
|
||||
@@ -15,7 +15,6 @@ pub enum RequestType {
|
||||
PromptsGet,
|
||||
PromptsList,
|
||||
CompletionComplete,
|
||||
Ping,
|
||||
}
|
||||
|
||||
impl RequestType {
|
||||
@@ -31,7 +30,6 @@ impl RequestType {
|
||||
RequestType::PromptsGet => "prompts/get",
|
||||
RequestType::PromptsList => "prompts/list",
|
||||
RequestType::CompletionComplete => "completion/complete",
|
||||
RequestType::Ping => "ping",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -41,15 +39,14 @@ impl RequestType {
|
||||
pub struct InitializeParams {
|
||||
pub protocol_version: u32,
|
||||
pub capabilities: ClientCapabilities,
|
||||
pub client_info: Implementation,
|
||||
pub client_info: EntityInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CallToolParams {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<HashMap<String, serde_json::Value>>,
|
||||
pub arguments: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -80,7 +77,6 @@ pub struct LoggingSetLevelParams {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsGetParams {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
@@ -105,13 +101,6 @@ pub struct PromptReference {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: PromptReferenceType,
|
||||
pub uri: Url,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PromptReferenceType {
|
||||
@@ -121,6 +110,13 @@ pub enum PromptReferenceType {
|
||||
Resource,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: String,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionArgument {
|
||||
@@ -133,7 +129,7 @@ pub struct CompletionArgument {
|
||||
pub struct InitializeResponse {
|
||||
pub protocol_version: u32,
|
||||
pub capabilities: ServerCapabilities,
|
||||
pub server_info: Implementation,
|
||||
pub server_info: EntityInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -145,39 +141,13 @@ pub struct ResourcesReadResponse {
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesListResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub resource_templates: Option<Vec<ResourceTemplate>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub resources: Option<Vec<Resource>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SamplingMessage {
|
||||
pub role: SamplingRole,
|
||||
pub content: SamplingContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SamplingRole {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum SamplingContent {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image")]
|
||||
Image { data: String, mime_type: String },
|
||||
pub resources: Vec<Resource>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsGetResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub prompt: String,
|
||||
}
|
||||
@@ -185,7 +155,7 @@ pub struct PromptsGetResponse {
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsListResponse {
|
||||
pub prompts: Vec<Prompt>,
|
||||
pub prompts: Vec<PromptInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -198,91 +168,61 @@ pub struct CompletionCompleteResponse {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionResult {
|
||||
pub values: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub has_more: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Prompt {
|
||||
pub struct PromptInfo {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<Vec<PromptArgument>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptArgument {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
}
|
||||
|
||||
// Shared Types
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClientCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub experimental: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sampling: Option<serde_json::Value>,
|
||||
pub sampling: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub experimental: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logging: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompts: Option<PromptsCapabilities>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logging: Option<HashMap<String, serde_json::Value>>,
|
||||
pub prompts: Option<HashMap<String, serde_json::Value>>,
|
||||
pub resources: Option<ResourcesCapabilities>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<ToolsCapabilities>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub list_changed: Option<bool>,
|
||||
pub tools: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourcesCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub subscribe: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolsCapabilities {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Implementation {
|
||||
pub struct EntityInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
}
|
||||
@@ -291,10 +231,6 @@ pub struct Implementation {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Resource {
|
||||
pub uri: Url,
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
@@ -302,23 +238,17 @@ pub struct Resource {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceContent {
|
||||
pub uri: Url,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blob: Option<String>,
|
||||
pub data: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceTemplate {
|
||||
pub uri_template: String,
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -330,16 +260,13 @@ pub enum LoggingLevel {
|
||||
Error,
|
||||
}
|
||||
|
||||
// Client Notifications
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum NotificationType {
|
||||
Initialized,
|
||||
Progress,
|
||||
Message,
|
||||
ResourcesUpdated,
|
||||
ResourcesListChanged,
|
||||
ToolsListChanged,
|
||||
PromptsListChanged,
|
||||
}
|
||||
|
||||
impl NotificationType {
|
||||
@@ -347,11 +274,6 @@ impl NotificationType {
|
||||
match self {
|
||||
NotificationType::Initialized => "notifications/initialized",
|
||||
NotificationType::Progress => "notifications/progress",
|
||||
NotificationType::Message => "notifications/message",
|
||||
NotificationType::ResourcesUpdated => "notifications/resources/updated",
|
||||
NotificationType::ResourcesListChanged => "notifications/resources/list_changed",
|
||||
NotificationType::ToolsListChanged => "notifications/tools/list_changed",
|
||||
NotificationType::PromptsListChanged => "notifications/prompts/list_changed",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,13 +288,12 @@ pub enum ClientNotification {
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProgressParams {
|
||||
pub progress_token: ProgressToken,
|
||||
pub progress_token: String,
|
||||
pub progress: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total: Option<f64>,
|
||||
}
|
||||
|
||||
pub type ProgressToken = String;
|
||||
// Helper Types that don't map directly to the protocol
|
||||
|
||||
pub enum CompletionTotal {
|
||||
Exact(u32),
|
||||
|
||||
@@ -237,7 +237,6 @@ gpui::actions!(
|
||||
ToggleFold,
|
||||
ToggleFoldRecursive,
|
||||
Format,
|
||||
FormatSelections,
|
||||
GoToDeclaration,
|
||||
GoToDeclarationSplit,
|
||||
GoToDefinition,
|
||||
@@ -295,7 +294,6 @@ gpui::actions!(
|
||||
RevealInFileManager,
|
||||
ReverseLines,
|
||||
RevertFile,
|
||||
ReloadFile,
|
||||
RevertSelectedHunks,
|
||||
Rewrap,
|
||||
ScrollCursorBottom,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user