Compare commits
143 Commits
v0.157.4
...
v0.158.1-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
363530ea84 | ||
|
|
7ce18afabf | ||
|
|
f9beb25fb0 | ||
|
|
474e670bbd | ||
|
|
fe0bcc063c | ||
|
|
69abe71bf7 | ||
|
|
9c3d80d6e8 | ||
|
|
834d50f0db | ||
|
|
bcdb10b3cb | ||
|
|
598939d186 | ||
|
|
9d944d0662 | ||
|
|
7d2628e805 | ||
|
|
84df3a0cad | ||
|
|
eb76065ad3 | ||
|
|
84018d7a2d | ||
|
|
57c55b32e1 | ||
|
|
a4357c429a | ||
|
|
103665ee28 | ||
|
|
2f960c4aba | ||
|
|
109ebc5f27 | ||
|
|
eddf70b5c4 | ||
|
|
128619899e | ||
|
|
a77ec94cbc | ||
|
|
a56f946a7d | ||
|
|
f944ebc4cb | ||
|
|
1dda039f38 | ||
|
|
182230a0ba | ||
|
|
b64919aa11 | ||
|
|
b752548742 | ||
|
|
d63a49647f | ||
|
|
c00f2d8842 | ||
|
|
973143fa35 | ||
|
|
d806df9f16 | ||
|
|
b682fc6d1d | ||
|
|
5445f898e8 | ||
|
|
56163b1e35 | ||
|
|
695176898e | ||
|
|
e3c6ba4bd7 | ||
|
|
8924b3fb5b | ||
|
|
1732441f48 | ||
|
|
096d37a1ed | ||
|
|
ba69f48ccf | ||
|
|
0eb6bfd323 | ||
|
|
4fa75a78b9 | ||
|
|
397e4bee0a | ||
|
|
db7417f3b5 | ||
|
|
be7b24fcf7 | ||
|
|
62de03f286 | ||
|
|
41ba4178fc | ||
|
|
6e2869a321 | ||
|
|
6986f081d0 | ||
|
|
3ff52a816e | ||
|
|
7d5fe66b54 | ||
|
|
792f583b97 | ||
|
|
6ec00cdb06 | ||
|
|
71a878aa39 | ||
|
|
f2337bbed1 | ||
|
|
fcf9e546da | ||
|
|
7dc069100d | ||
|
|
5b207ba238 | ||
|
|
325f106c8b | ||
|
|
ec5d6e96bb | ||
|
|
54683ff2b9 | ||
|
|
cdead5760a | ||
|
|
39468de8c6 | ||
|
|
6491148196 | ||
|
|
0b10fd5098 | ||
|
|
74cc90887a | ||
|
|
875c0cb09f | ||
|
|
aefc559f43 | ||
|
|
bebe24ea77 | ||
|
|
f73a076a63 | ||
|
|
b2e844f2ec | ||
|
|
9e14fd915f | ||
|
|
c85a3cc117 | ||
|
|
22ac178f9d | ||
|
|
c709b66f35 | ||
|
|
b739cfa73f | ||
|
|
0fc3072362 | ||
|
|
3cbaa08d89 | ||
|
|
12c9f0f723 | ||
|
|
f280b29859 | ||
|
|
550064f80f | ||
|
|
f33b8abc72 | ||
|
|
22ea7cef7a | ||
|
|
f1c45d988e | ||
|
|
84b61c8b1a | ||
|
|
5cf0217549 | ||
|
|
c21f26c419 | ||
|
|
d976c5f1b6 | ||
|
|
79ed217e42 | ||
|
|
0a7468c89f | ||
|
|
518f8cc5b7 | ||
|
|
ccaf3268f8 | ||
|
|
1691652948 | ||
|
|
4726f30bd6 | ||
|
|
36b9e40085 | ||
|
|
e962839d13 | ||
|
|
eea600ecc3 | ||
|
|
3c6989323f | ||
|
|
596d8b2fe3 | ||
|
|
9bc4e3b4ae | ||
|
|
972886c29e | ||
|
|
d2b4fa20ef | ||
|
|
21c27cecba | ||
|
|
4de05d18ed | ||
|
|
8c9a05b2a8 | ||
|
|
348e317695 | ||
|
|
281c60f12d | ||
|
|
6859482020 | ||
|
|
7c306a5a0e | ||
|
|
b75532fad7 | ||
|
|
e3ff2ced79 | ||
|
|
5841ac406d | ||
|
|
f6f5ad138d | ||
|
|
db50467bbc | ||
|
|
fe1078ef68 | ||
|
|
5cf4ac16d6 | ||
|
|
05b2010db5 | ||
|
|
d8484c57e1 | ||
|
|
fcfd769b39 | ||
|
|
285fb51771 | ||
|
|
ed484ecf5f | ||
|
|
ab34342664 | ||
|
|
53cc82b132 | ||
|
|
cae548a50d | ||
|
|
69711660ab | ||
|
|
b2e1572820 | ||
|
|
66ea96839a | ||
|
|
3db789ed90 | ||
|
|
99a6a3d5e3 | ||
|
|
d316577fd5 | ||
|
|
711180981b | ||
|
|
49c75eb062 | ||
|
|
f1053ff525 | ||
|
|
817a41c4dc | ||
|
|
bc23d1e666 | ||
|
|
bc4abd2b29 | ||
|
|
71f4ca67c2 | ||
|
|
f05b440572 | ||
|
|
1cbaca667f | ||
|
|
8911fd46e1 | ||
|
|
926e54bd4a |
2
.github/workflows/bump_collab_staging.yml
vendored
2
.github/workflows/bump_collab_staging.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
2
.github/workflows/bump_patch_version.yml
vendored
2
.github/workflows/bump_patch_version.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.branch }}
|
||||
ssh-key: ${{ secrets.ZED_BOT_DEPLOY_KEY }}
|
||||
|
||||
61
.github/workflows/ci.yml
vendored
61
.github/workflows/ci.yml
vendored
@@ -34,7 +34,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -96,10 +96,13 @@ jobs:
|
||||
uses: ./.github/actions/run_tests
|
||||
|
||||
- name: Build collab
|
||||
run: cargo build -p collab
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p collab
|
||||
|
||||
- name: Build other binaries and features
|
||||
run: cargo build --workspace --bins --all-features; cargo check -p gpui --features "macos-blade"
|
||||
run: |
|
||||
RUSTFLAGS="-D warnings" cargo build --workspace --bins --all-features
|
||||
cargo check -p gpui --features "macos-blade"
|
||||
RUSTFLAGS="-D warnings" cargo build -p remote_server
|
||||
|
||||
linux_tests:
|
||||
timeout-minutes: 60
|
||||
@@ -111,7 +114,7 @@ jobs:
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -131,7 +134,33 @@ jobs:
|
||||
uses: ./.github/actions/run_tests
|
||||
|
||||
- name: Build Zed
|
||||
run: cargo build -p zed
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p zed
|
||||
|
||||
build_remote_server:
|
||||
timeout-minutes: 60
|
||||
name: (Linux) Build Remote Server
|
||||
runs-on:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
cache-provider: "buildjet"
|
||||
|
||||
- name: Install Clang & Mold
|
||||
run: ./script/remote-server && ./script/install-mold 2.34.0
|
||||
|
||||
- name: Build Remote Server
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p remote_server
|
||||
|
||||
# todo(windows): Actually run the tests
|
||||
windows_tests:
|
||||
@@ -140,7 +169,7 @@ jobs:
|
||||
runs-on: hosted-windows-1
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -155,7 +184,7 @@ jobs:
|
||||
run: cargo xtask clippy
|
||||
|
||||
- name: Build Zed
|
||||
run: cargo build -p zed
|
||||
run: $env:RUSTFLAGS="-D warnings"; cargo build
|
||||
|
||||
bundle-mac:
|
||||
timeout-minutes: 60
|
||||
@@ -181,7 +210,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
# We need to fetch more than one commit so that `script/draft-release-notes`
|
||||
# is able to diff between the current and previous tag.
|
||||
@@ -219,20 +248,20 @@ jobs:
|
||||
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
|
||||
|
||||
- name: Upload app bundle (universal) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}.dmg
|
||||
path: target/release/Zed.dmg
|
||||
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
|
||||
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
|
||||
|
||||
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
|
||||
@@ -266,7 +295,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -283,7 +312,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
|
||||
@@ -313,7 +342,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -330,7 +359,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
|
||||
|
||||
2
.github/workflows/danger.yml
vendored
2
.github/workflows/danger.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
|
||||
- uses: pnpm/action-setup@fe02b34f77f8bc703788d5817da081398fad5dd2 # v4.0.0
|
||||
with:
|
||||
|
||||
10
.github/workflows/deploy_cloudflare.yml
vendored
10
.github/workflows/deploy_cloudflare.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -36,28 +36,28 @@ jobs:
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Deploy Docs
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: pages deploy target/deploy --project-name=docs
|
||||
|
||||
- name: Deploy Install
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: r2 object put -f script/install.sh zed-open-source-website-assets/install.sh
|
||||
|
||||
- name: Deploy Docs Workers
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: deploy .cloudflare/docs-proxy/src/worker.js
|
||||
|
||||
- name: Deploy Install Workers
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
|
||||
8
.github/workflows/deploy_collab.yml
vendored
8
.github/workflows/deploy_collab.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
needs: style
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -71,7 +71,7 @@ jobs:
|
||||
run: doctl registry login
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
|
||||
- uses: pnpm/action-setup@fe02b34f77f8bc703788d5817da081398fad5dd2 # v4.0.0
|
||||
with:
|
||||
|
||||
2
.github/workflows/publish_extension_cli.yml
vendored
2
.github/workflows/publish_extension_cli.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
- ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/randomized_tests.yml
vendored
2
.github/workflows/randomized_tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/release_actions.yml
vendored
2
.github/workflows/release_actions.yml
vendored
@@ -1,3 +1,5 @@
|
||||
name: Release Actions
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
12
.github/workflows/release_nightly.yml
vendored
12
.github/workflows/release_nightly.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
needs: style
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -75,7 +75,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -182,7 +182,7 @@ jobs:
|
||||
- bundle-linux-arm
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
name: Update All Top Ranking Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
@@ -8,11 +10,16 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
steps:
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
python-version: "3.11"
|
||||
architecture: "x64"
|
||||
cache: "pip"
|
||||
- run: pip install -r script/update_top_ranking_issues/requirements.txt
|
||||
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
|
||||
- name: Install Python 3.13
|
||||
run: uv python install 3.13
|
||||
- name: Install dependencies
|
||||
run: uv sync --project script/update_top_ranking_issues -p 3.13
|
||||
- name: Run script
|
||||
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
name: Update Weekly Top Ranking Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 15 * * *"
|
||||
@@ -8,11 +10,16 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
steps:
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
python-version: "3.11"
|
||||
architecture: "x64"
|
||||
cache: "pip"
|
||||
- run: pip install -r script/update_top_ranking_issues/requirements.txt
|
||||
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
|
||||
- name: Install Python 3.13
|
||||
run: uv python install 3.13
|
||||
- name: Install dependencies
|
||||
run: uv sync --project script/update_top_ranking_issues -p 3.13
|
||||
- name: Run script
|
||||
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
|
||||
|
||||
781
Cargo.lock
generated
781
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
24
Cargo.toml
24
Cargo.toml
@@ -52,7 +52,6 @@ members = [
|
||||
"crates/indexed_docs",
|
||||
"crates/inline_completion_button",
|
||||
"crates/install_cli",
|
||||
"crates/isahc_http_client",
|
||||
"crates/journal",
|
||||
"crates/language",
|
||||
"crates/language_model",
|
||||
@@ -88,6 +87,7 @@ members = [
|
||||
"crates/remote",
|
||||
"crates/remote_server",
|
||||
"crates/repl",
|
||||
"crates/reqwest_client",
|
||||
"crates/rich_text",
|
||||
"crates/rope",
|
||||
"crates/rpc",
|
||||
@@ -122,6 +122,7 @@ members = [
|
||||
"crates/ui",
|
||||
"crates/ui_input",
|
||||
"crates/ui_macros",
|
||||
"crates/reqwest_client",
|
||||
"crates/util",
|
||||
"crates/vcs_menu",
|
||||
"crates/vim",
|
||||
@@ -144,7 +145,6 @@ members = [
|
||||
"extensions/elm",
|
||||
"extensions/emmet",
|
||||
"extensions/erlang",
|
||||
"extensions/gleam",
|
||||
"extensions/glsl",
|
||||
"extensions/haskell",
|
||||
"extensions/html",
|
||||
@@ -156,7 +156,6 @@ members = [
|
||||
"extensions/proto",
|
||||
"extensions/purescript",
|
||||
"extensions/ruff",
|
||||
"extensions/ruby",
|
||||
"extensions/slash-commands-example",
|
||||
"extensions/snippets",
|
||||
"extensions/svelte",
|
||||
@@ -220,7 +219,7 @@ git = { path = "crates/git" }
|
||||
git_hosting_providers = { path = "crates/git_hosting_providers" }
|
||||
go_to_line = { path = "crates/go_to_line" }
|
||||
google_ai = { path = "crates/google_ai" }
|
||||
gpui = { path = "crates/gpui" }
|
||||
gpui = { path = "crates/gpui", default-features = false, features = ["http_client"]}
|
||||
gpui_macros = { path = "crates/gpui_macros" }
|
||||
headless = { path = "crates/headless" }
|
||||
html_to_markdown = { path = "crates/html_to_markdown" }
|
||||
@@ -229,7 +228,6 @@ image_viewer = { path = "crates/image_viewer" }
|
||||
indexed_docs = { path = "crates/indexed_docs" }
|
||||
inline_completion_button = { path = "crates/inline_completion_button" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
isahc_http_client = { path = "crates/isahc_http_client" }
|
||||
journal = { path = "crates/journal" }
|
||||
language = { path = "crates/language" }
|
||||
language_model = { path = "crates/language_model" }
|
||||
@@ -266,6 +264,7 @@ release_channel = { path = "crates/release_channel" }
|
||||
remote = { path = "crates/remote" }
|
||||
remote_server = { path = "crates/remote_server" }
|
||||
repl = { path = "crates/repl" }
|
||||
reqwest_client = { path = "crates/reqwest_client" }
|
||||
rich_text = { path = "crates/rich_text" }
|
||||
rope = { path = "crates/rope" }
|
||||
rpc = { path = "crates/rpc" }
|
||||
@@ -327,7 +326,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
|
||||
async-recursion = "1.0.0"
|
||||
async-tar = "0.5.0"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.23"
|
||||
async-tungstenite = "0.24"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
base64 = "0.22"
|
||||
@@ -336,6 +335,7 @@ blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
|
||||
blake3 = "1.5.3"
|
||||
bytes = "1.0"
|
||||
cargo_metadata = "0.18"
|
||||
cargo_toml = "0.20"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
@@ -367,10 +367,6 @@ ignore = "0.4.22"
|
||||
image = "0.25.1"
|
||||
indexmap = { version = "1.6.2", features = ["serde"] }
|
||||
indoc = "2"
|
||||
# We explicitly disable http2 support in isahc.
|
||||
isahc = { version = "1.7.2", default-features = false, features = [
|
||||
"text-decoding",
|
||||
] }
|
||||
itertools = "0.13.0"
|
||||
jsonwebtoken = "9.3"
|
||||
libc = "0.2"
|
||||
@@ -395,6 +391,7 @@ pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
rand = "0.8.5"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "fd110f6998da16bbca97b6dddda9be7827c50e29", default-features = false, features = ["charset", "http2", "macos-system-configuration", "rustls-tls-native-roots", "stream"]}
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { version = "0.15", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
@@ -439,7 +436,7 @@ time = { version = "0.3", features = [
|
||||
] }
|
||||
tiny_http = "0.8"
|
||||
toml = "0.8"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio = { version = "1" }
|
||||
tower-http = "0.4.4"
|
||||
tree-sitter = { version = "0.23", features = ["wasm"] }
|
||||
tree-sitter-bash = "0.23"
|
||||
@@ -452,6 +449,7 @@ tree-sitter-go = "0.23"
|
||||
tree-sitter-go-mod = { git = "https://github.com/zed-industries/tree-sitter-go-mod", rev = "a9aea5e358cde4d0f8ff20b7bc4fa311e359c7ca", package = "tree-sitter-gomod" }
|
||||
tree-sitter-gowork = { git = "https://github.com/zed-industries/tree-sitter-go-work", rev = "acb0617bf7f4fda02c6217676cc64acb89536dc7" }
|
||||
tree-sitter-heex = { git = "https://github.com/zed-industries/tree-sitter-heex", rev = "1dd45142fbb05562e35b2040c6129c9bca346592" }
|
||||
tree-sitter-diff = "0.1.0"
|
||||
tree-sitter-html = "0.20"
|
||||
tree-sitter-jsdoc = "0.23"
|
||||
tree-sitter-json = "0.23"
|
||||
@@ -479,9 +477,11 @@ wasmtime = { version = "24", default-features = false, features = [
|
||||
wasmtime-wasi = "24"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.201"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
version = "0.39"
|
||||
git = "https://github.com/zed-industries/async-stripe"
|
||||
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
default-features = false
|
||||
features = [
|
||||
"runtime-tokio-hyper-rustls",
|
||||
|
||||
2
Cross.toml
Normal file
2
Cross.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[build]
|
||||
dockerfile = "Dockerfile-cross"
|
||||
@@ -13,30 +13,9 @@ ARG GITHUB_SHA
|
||||
|
||||
ENV GITHUB_SHA=$GITHUB_SHA
|
||||
|
||||
# At some point in the past 3 weeks, additional dependencies on `xkbcommon` and
|
||||
# `xkbcommon-x11` were introduced into collab.
|
||||
#
|
||||
# A `git bisect` points to this commit as being the culprit: `b8e6098f60e5dabe98fe8281f993858dacc04a55`.
|
||||
#
|
||||
# Now when we try to build collab for the Docker image, it fails with the following
|
||||
# error:
|
||||
#
|
||||
# ```
|
||||
# 985.3 = note: /usr/bin/ld: cannot find -lxkbcommon: No such file or directory
|
||||
# 985.3 /usr/bin/ld: cannot find -lxkbcommon-x11: No such file or directory
|
||||
# 985.3 collect2: error: ld returned 1 exit status
|
||||
# ```
|
||||
#
|
||||
# The last successful deploys were at:
|
||||
# - Staging: `4f408ec65a3867278322a189b4eb20f1ab51f508`
|
||||
# - Production: `fc4c533d0a8c489e5636a4249d2b52a80039fbd7`
|
||||
#
|
||||
# Also add `cmake`, since we need it to build `wasmtime`.
|
||||
#
|
||||
# Installing these as a temporary workaround, but I think ideally we'd want to figure
|
||||
# out what caused them to be included in the first place.
|
||||
RUN apt-get update; \
|
||||
apt-get install -y --no-install-recommends libxkbcommon-dev libxkbcommon-x11-dev cmake
|
||||
apt-get install -y --no-install-recommends cmake
|
||||
|
||||
RUN --mount=type=cache,target=./script/node_modules \
|
||||
--mount=type=cache,target=/usr/local/cargo/registry \
|
||||
|
||||
17
Dockerfile-cross
Normal file
17
Dockerfile-cross
Normal file
@@ -0,0 +1,17 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG CROSS_BASE_IMAGE
|
||||
FROM ${CROSS_BASE_IMAGE}
|
||||
WORKDIR /app
|
||||
ARG TZ=Etc/UTC \
|
||||
LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV CARGO_TERM_COLOR=always
|
||||
|
||||
COPY script/install-mold script/
|
||||
RUN ./script/install-mold "2.34.0"
|
||||
COPY script/remote-server script/
|
||||
RUN ./script/remote-server
|
||||
|
||||
COPY . .
|
||||
16
Dockerfile-cross.dockerignore
Normal file
16
Dockerfile-cross.dockerignore
Normal file
@@ -0,0 +1,16 @@
|
||||
.git
|
||||
.github
|
||||
**/.gitignore
|
||||
**/.gitkeep
|
||||
.gitattributes
|
||||
.mailmap
|
||||
**/target
|
||||
zed.xcworkspace
|
||||
.DS_Store
|
||||
compose.yml
|
||||
plugins/bin
|
||||
script/node_modules
|
||||
styles/node_modules
|
||||
crates/collab/static/styles.css
|
||||
vendor/bin
|
||||
assets/themes/
|
||||
@@ -20,6 +20,7 @@
|
||||
"bashrc": "terminal",
|
||||
"bmp": "image",
|
||||
"c": "c",
|
||||
"c++": "cpp",
|
||||
"cc": "cpp",
|
||||
"cjs": "javascript",
|
||||
"coffee": "coffeescript",
|
||||
@@ -27,6 +28,7 @@
|
||||
"cpp": "cpp",
|
||||
"css": "css",
|
||||
"csv": "storage",
|
||||
"cxx": "cpp",
|
||||
"cts": "typescript",
|
||||
"dart": "dart",
|
||||
"dat": "storage",
|
||||
@@ -66,11 +68,13 @@
|
||||
"heex": "elixir",
|
||||
"heic": "image",
|
||||
"heif": "image",
|
||||
"hh": "cpp",
|
||||
"hpp": "cpp",
|
||||
"hrl": "erlang",
|
||||
"hs": "haskell",
|
||||
"htm": "template",
|
||||
"html": "template",
|
||||
"hxx": "cpp",
|
||||
"ib": "storage",
|
||||
"ico": "image",
|
||||
"ini": "settings",
|
||||
|
||||
@@ -664,7 +664,8 @@
|
||||
"shift-up": "terminal::ScrollLineUp",
|
||||
"shift-down": "terminal::ScrollLineDown",
|
||||
"shift-home": "terminal::ScrollToTop",
|
||||
"shift-end": "terminal::ScrollToBottom"
|
||||
"shift-end": "terminal::ScrollToBottom",
|
||||
"ctrl-shift-space": "terminal::ToggleViMode"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -395,6 +395,7 @@
|
||||
// Change the default action on `menu::Confirm` by setting the parameter
|
||||
// "alt-cmd-o": ["projects::OpenRecent", {"create_new_window": true }],
|
||||
"alt-cmd-o": "projects::OpenRecent",
|
||||
"ctrl-cmd-o": "projects::OpenRemote",
|
||||
"alt-cmd-b": "branches::OpenRecent",
|
||||
"ctrl-~": "workspace::NewTerminal",
|
||||
"cmd-s": "workspace::Save",
|
||||
@@ -678,7 +679,8 @@
|
||||
"cmd-home": "terminal::ScrollToTop",
|
||||
"cmd-end": "terminal::ScrollToBottom",
|
||||
"shift-home": "terminal::ScrollToTop",
|
||||
"shift-end": "terminal::ScrollToBottom"
|
||||
"shift-end": "terminal::ScrollToBottom",
|
||||
"ctrl-shift-space": "terminal::ToggleViMode"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -128,6 +128,10 @@
|
||||
"shift-m": "vim::WindowMiddle",
|
||||
"shift-l": "vim::WindowBottom",
|
||||
// z commands
|
||||
"z enter": ["workspace::SendKeystrokes", "z t ^"],
|
||||
"z -": ["workspace::SendKeystrokes", "z b ^"],
|
||||
"z ^": ["workspace::SendKeystrokes", "shift-h k z b ^"],
|
||||
"z +": ["workspace::SendKeystrokes", "shift-l j z t ^"],
|
||||
"z t": "editor::ScrollCursorTop",
|
||||
"z z": "editor::ScrollCursorCenter",
|
||||
"z .": ["workspace::SendKeystrokes", "z z ^"],
|
||||
@@ -252,6 +256,7 @@
|
||||
"@": ["vim::PushOperator", "ReplayRegister"],
|
||||
"ctrl-pagedown": "pane::ActivateNextItem",
|
||||
"ctrl-pageup": "pane::ActivatePrevItem",
|
||||
"insert": "vim::InsertBefore",
|
||||
// tree-sitter related commands
|
||||
"[ x": "editor::SelectLargerSyntaxNode",
|
||||
"] x": "editor::SelectSmallerSyntaxNode",
|
||||
@@ -334,7 +339,8 @@
|
||||
"ctrl-t": "vim::Indent",
|
||||
"ctrl-d": "vim::Outdent",
|
||||
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
|
||||
"ctrl-r": ["vim::PushOperator", "Register"]
|
||||
"ctrl-r": ["vim::PushOperator", "Register"],
|
||||
"insert": "vim::ToggleReplace"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -353,7 +359,8 @@
|
||||
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
|
||||
"backspace": "vim::UndoReplace",
|
||||
"tab": "vim::Tab",
|
||||
"enter": "vim::Enter"
|
||||
"enter": "vim::Enter",
|
||||
"insert": "vim::InsertBefore"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -118,8 +118,8 @@
|
||||
// "bar"
|
||||
// 2. A block that surrounds the following character
|
||||
// "block"
|
||||
// 3. An underline that runs along the following character
|
||||
// "underscore"
|
||||
// 3. An underline / underscore that runs along the following character
|
||||
// "underline"
|
||||
// 4. A box drawn around the following character
|
||||
// "hollow"
|
||||
//
|
||||
@@ -494,7 +494,14 @@
|
||||
// Position of the close button on the editor tabs.
|
||||
"close_position": "right",
|
||||
// Whether to show the file icon for a tab.
|
||||
"file_icons": false
|
||||
"file_icons": false,
|
||||
// What to do after closing the current tab.
|
||||
//
|
||||
// 1. Activate the tab that was open previously (default)
|
||||
// "History"
|
||||
// 2. Activate the neighbour tab (prefers the right one, if present)
|
||||
// "Neighbour"
|
||||
"activate_on_close": "history"
|
||||
},
|
||||
// Settings related to preview tabs.
|
||||
"preview_tabs": {
|
||||
@@ -684,8 +691,8 @@
|
||||
// "block"
|
||||
// 2. A vertical bar
|
||||
// "bar"
|
||||
// 3. An underline that runs along the following character
|
||||
// "underscore"
|
||||
// 3. An underline / underscore that runs along the following character
|
||||
// "underline"
|
||||
// 4. A box drawn around the following character
|
||||
// "hollow"
|
||||
//
|
||||
@@ -817,6 +824,7 @@
|
||||
// Different settings for specific languages.
|
||||
"languages": {
|
||||
"Astro": {
|
||||
"language_servers": ["astro-language-server", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-astro"]
|
||||
@@ -843,6 +851,10 @@
|
||||
"Dart": {
|
||||
"tab_size": 2
|
||||
},
|
||||
"Diff": {
|
||||
"remove_trailing_whitespace_on_save": false,
|
||||
"ensure_final_newline_on_save": false
|
||||
},
|
||||
"Elixir": {
|
||||
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
|
||||
},
|
||||
|
||||
7
assets/settings/initial_server_settings.json
Normal file
7
assets/settings/initial_server_settings.json
Normal file
@@ -0,0 +1,7 @@
|
||||
// Server-specific settings
|
||||
//
|
||||
// For a full list of overridable settings, and general information on settings,
|
||||
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
|
||||
{
|
||||
"lsp": {}
|
||||
}
|
||||
@@ -101,6 +101,7 @@ impl ActivityIndicator {
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
buffer.set_capability(language::Capability::ReadOnly, cx);
|
||||
})?;
|
||||
workspace.update(&mut cx, |workspace, cx| {
|
||||
workspace.add_item_to_active_pane(
|
||||
|
||||
@@ -26,6 +26,3 @@ serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::{
|
||||
PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata,
|
||||
SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, WorkflowStepResolution,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use client::{proto, Client, Status};
|
||||
@@ -697,7 +697,9 @@ impl AssistantPanel {
|
||||
log::error!("no context found with ID: {}", context_id.to_proto());
|
||||
return;
|
||||
};
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx).log_err();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx)
|
||||
.log_err()
|
||||
.flatten();
|
||||
|
||||
let assistant_panel = cx.view().downgrade();
|
||||
let editor = cx.new_view(|cx| {
|
||||
@@ -971,7 +973,8 @@ impl AssistantPanel {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let workspace = this.workspace.clone();
|
||||
let project = this.project.clone();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err();
|
||||
let lsp_adapter_delegate =
|
||||
make_lsp_adapter_delegate(&project, cx).log_err().flatten();
|
||||
|
||||
let fs = this.fs.clone();
|
||||
let project = this.project.clone();
|
||||
@@ -1001,7 +1004,9 @@ impl AssistantPanel {
|
||||
None
|
||||
} else {
|
||||
let context = self.context_store.update(cx, |store, cx| store.create(cx));
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx).log_err();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx)
|
||||
.log_err()
|
||||
.flatten();
|
||||
|
||||
let assistant_panel = cx.view().downgrade();
|
||||
let editor = cx.new_view(|cx| {
|
||||
@@ -1207,7 +1212,7 @@ impl AssistantPanel {
|
||||
let project = self.project.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err().flatten();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let context = context.await?;
|
||||
@@ -1254,7 +1259,9 @@ impl AssistantPanel {
|
||||
.update(cx, |store, cx| store.open_remote_context(id, cx));
|
||||
let fs = self.fs.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx).log_err();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx)
|
||||
.log_err()
|
||||
.flatten();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let context = context.await?;
|
||||
@@ -1496,6 +1503,13 @@ struct WorkflowAssist {
|
||||
|
||||
type MessageHeader = MessageMetadata;
|
||||
|
||||
#[derive(Clone)]
|
||||
enum AssistError {
|
||||
PaymentRequired,
|
||||
MaxMonthlySpendReached,
|
||||
Message(SharedString),
|
||||
}
|
||||
|
||||
pub struct ContextEditor {
|
||||
context: Model<Context>,
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -1514,7 +1528,7 @@ pub struct ContextEditor {
|
||||
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
|
||||
active_workflow_step: Option<ActiveWorkflowStep>,
|
||||
assistant_panel: WeakView<AssistantPanel>,
|
||||
error_message: Option<SharedString>,
|
||||
last_error: Option<AssistError>,
|
||||
show_accept_terms: bool,
|
||||
pub(crate) slash_menu_handle:
|
||||
PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
|
||||
@@ -1553,7 +1567,7 @@ impl ContextEditor {
|
||||
editor.set_show_runnables(false, cx);
|
||||
editor.set_show_wrap_guides(false, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_completion_provider(Box::new(completion_provider));
|
||||
editor.set_completion_provider(Some(Box::new(completion_provider)));
|
||||
editor.set_collaboration_hub(Box::new(project.clone()));
|
||||
editor
|
||||
});
|
||||
@@ -1585,7 +1599,7 @@ impl ContextEditor {
|
||||
workflow_steps: HashMap::default(),
|
||||
active_workflow_step: None,
|
||||
assistant_panel,
|
||||
error_message: None,
|
||||
last_error: None,
|
||||
show_accept_terms: false,
|
||||
slash_menu_handle: Default::default(),
|
||||
dragged_file_worktrees: Vec::new(),
|
||||
@@ -1629,7 +1643,7 @@ impl ContextEditor {
|
||||
}
|
||||
|
||||
if !self.apply_active_workflow_step(cx) {
|
||||
self.error_message = None;
|
||||
self.last_error = None;
|
||||
self.send_to_model(cx);
|
||||
cx.notify();
|
||||
}
|
||||
@@ -1779,7 +1793,7 @@ impl ContextEditor {
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
|
||||
self.error_message = None;
|
||||
self.last_error = None;
|
||||
|
||||
if self
|
||||
.context
|
||||
@@ -2284,7 +2298,13 @@ impl ContextEditor {
|
||||
}
|
||||
ContextEvent::Operation(_) => {}
|
||||
ContextEvent::ShowAssistError(error_message) => {
|
||||
self.error_message = Some(error_message.clone());
|
||||
self.last_error = Some(AssistError::Message(error_message.clone()));
|
||||
}
|
||||
ContextEvent::ShowPaymentRequiredError => {
|
||||
self.last_error = Some(AssistError::PaymentRequired);
|
||||
}
|
||||
ContextEvent::ShowMaxMonthlySpendReachedError => {
|
||||
self.last_error = Some(AssistError::MaxMonthlySpendReached);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4298,6 +4318,154 @@ impl ContextEditor {
|
||||
focus_handle.dispatch_action(&Assist, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
|
||||
let last_error = self.last_error.as_ref()?;
|
||||
|
||||
Some(
|
||||
div()
|
||||
.absolute()
|
||||
.right_3()
|
||||
.bottom_12()
|
||||
.max_w_96()
|
||||
.py_2()
|
||||
.px_3()
|
||||
.elevation_2(cx)
|
||||
.occlude()
|
||||
.child(match last_error {
|
||||
AssistError::PaymentRequired => self.render_payment_required_error(cx),
|
||||
AssistError::MaxMonthlySpendReached => {
|
||||
self.render_max_monthly_spend_reached_error(cx)
|
||||
}
|
||||
AssistError::Message(error_message) => {
|
||||
self.render_assist_error(error_message, cx)
|
||||
}
|
||||
})
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
|
||||
const ACCOUNT_URL: &str = "https://zed.dev/account";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.open_url(ACCOUNT_URL);
|
||||
cx.notify();
|
||||
},
|
||||
)))
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
|
||||
const ACCOUNT_URL: &str = "https://zed.dev/account";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(
|
||||
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||
cx.listener(|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.open_url(ACCOUNT_URL);
|
||||
cx.notify();
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_assist_error(
|
||||
&self,
|
||||
error_message: &SharedString,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> AnyElement {
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(
|
||||
Label::new("Error interacting with language model")
|
||||
.weight(FontWeight::MEDIUM),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(error_message.clone())),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
|
||||
@@ -4434,48 +4602,7 @@ impl Render for ContextEditor {
|
||||
.child(element),
|
||||
)
|
||||
})
|
||||
.when_some(self.error_message.clone(), |this, error_message| {
|
||||
this.child(
|
||||
div()
|
||||
.absolute()
|
||||
.right_3()
|
||||
.bottom_12()
|
||||
.max_w_96()
|
||||
.py_2()
|
||||
.px_3()
|
||||
.elevation_2(cx)
|
||||
.occlude()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(
|
||||
Label::new("Error interacting with language model")
|
||||
.weight(FontWeight::MEDIUM),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(error_message)),
|
||||
)
|
||||
.child(h_flex().justify_end().mt_1().child(
|
||||
Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.error_message = None;
|
||||
cx.notify();
|
||||
},
|
||||
)),
|
||||
)),
|
||||
),
|
||||
)
|
||||
})
|
||||
.children(self.render_last_error(cx))
|
||||
.child(
|
||||
h_flex().w_full().relative().child(
|
||||
h_flex()
|
||||
@@ -5505,22 +5632,21 @@ fn render_docs_slash_command_trailer(
|
||||
fn make_lsp_adapter_delegate(
|
||||
project: &Model<Project>,
|
||||
cx: &mut AppContext,
|
||||
) -> Result<Arc<dyn LspAdapterDelegate>> {
|
||||
) -> Result<Option<Arc<dyn LspAdapterDelegate>>> {
|
||||
project.update(cx, |project, cx| {
|
||||
// TODO: Find the right worktree.
|
||||
let worktree = project
|
||||
.worktrees(cx)
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("no worktrees when constructing LocalLspAdapterDelegate"))?;
|
||||
let Some(worktree) = project.worktrees(cx).next() else {
|
||||
return Ok(None::<Arc<dyn LspAdapterDelegate>>);
|
||||
};
|
||||
let http_client = project.client().http_client().clone();
|
||||
project.lsp_store().update(cx, |lsp_store, cx| {
|
||||
Ok(LocalLspAdapterDelegate::new(
|
||||
Ok(Some(LocalLspAdapterDelegate::new(
|
||||
lsp_store,
|
||||
&worktree,
|
||||
http_client,
|
||||
project.fs().clone(),
|
||||
cx,
|
||||
) as Arc<dyn LspAdapterDelegate>)
|
||||
) as Arc<dyn LspAdapterDelegate>))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ use gpui::{
|
||||
|
||||
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
|
||||
use language_model::{
|
||||
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
|
||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
||||
@@ -294,6 +295,8 @@ impl ContextOperation {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ContextEvent {
|
||||
ShowAssistError(SharedString),
|
||||
ShowPaymentRequiredError,
|
||||
ShowMaxMonthlySpendReachedError,
|
||||
MessagesEdited,
|
||||
SummaryChanged,
|
||||
StreamedCompletion,
|
||||
@@ -2112,25 +2115,36 @@ impl Context {
|
||||
let result = stream_completion.await;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let error_message = result
|
||||
.as_ref()
|
||||
.err()
|
||||
.map(|error| error.to_string().trim().to_string());
|
||||
|
||||
if let Some(error_message) = error_message.as_ref() {
|
||||
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
|
||||
error_message.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
if let Some(error_message) = error_message.as_ref() {
|
||||
metadata.status =
|
||||
MessageStatus::Error(SharedString::from(error_message.clone()));
|
||||
let error_message = if let Some(error) = result.as_ref().err() {
|
||||
if error.is::<PaymentRequiredError>() {
|
||||
cx.emit(ContextEvent::ShowPaymentRequiredError);
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Canceled;
|
||||
});
|
||||
Some(error.to_string())
|
||||
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Canceled;
|
||||
});
|
||||
Some(error.to_string())
|
||||
} else {
|
||||
metadata.status = MessageStatus::Done;
|
||||
let error_message = error.to_string().trim().to_string();
|
||||
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
|
||||
error_message.clone(),
|
||||
)));
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status =
|
||||
MessageStatus::Error(SharedString::from(error_message.clone()));
|
||||
});
|
||||
Some(error_message)
|
||||
}
|
||||
});
|
||||
} else {
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Done;
|
||||
});
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let language_name = this
|
||||
@@ -2146,7 +2160,7 @@ impl Context {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
language_name,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -267,7 +267,7 @@ impl InlineAssistant {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: buffer.language().map(|language| language.name()),
|
||||
language_name: buffer.language().map(|language| language.name().to_proto()),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -788,7 +788,7 @@ impl InlineAssistant {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -2278,7 +2278,7 @@ impl InlineAssist {
|
||||
struct InlineAssistantError;
|
||||
|
||||
let id =
|
||||
NotificationId::identified::<InlineAssistantError>(
|
||||
NotificationId::composite::<InlineAssistantError>(
|
||||
assist_id.0,
|
||||
);
|
||||
|
||||
@@ -2954,7 +2954,7 @@ impl CodegenAlternative {
|
||||
model_provider: model_provider_id.to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
language_name,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -521,9 +521,9 @@ impl PromptLibrary {
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_use_modal_editing(false);
|
||||
editor.set_current_line_highlight(Some(CurrentLineHighlight::None));
|
||||
editor.set_completion_provider(Box::new(
|
||||
editor.set_completion_provider(Some(Box::new(
|
||||
SlashCommandCompletionProvider::new(None, None),
|
||||
));
|
||||
)));
|
||||
if focus {
|
||||
editor.focus(cx);
|
||||
}
|
||||
|
||||
@@ -38,7 +38,10 @@ impl Settings for SlashCommandSettings {
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _cx: &mut AppContext) -> Result<Self> {
|
||||
SettingsSources::<Self::FileContent>::json_merge_with(
|
||||
[sources.default].into_iter().chain(sources.user),
|
||||
[sources.default]
|
||||
.into_iter()
|
||||
.chain(sources.user)
|
||||
.chain(sources.server),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ impl TerminalInlineAssist {
|
||||
struct InlineAssistantError;
|
||||
|
||||
let id =
|
||||
NotificationId::identified::<InlineAssistantError>(
|
||||
NotificationId::composite::<InlineAssistantError>(
|
||||
assist_id.0,
|
||||
);
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ impl Settings for AutoUpdateSetting {
|
||||
type FileContent = Option<AutoUpdateSettingContent>;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
let auto_update = [sources.release_channel, sources.user]
|
||||
let auto_update = [sources.server, sources.release_channel, sources.user]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
@@ -464,6 +464,7 @@ impl AutoUpdater {
|
||||
smol::fs::create_dir_all(&platform_dir).await.ok();
|
||||
|
||||
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
|
||||
|
||||
if smol::fs::metadata(&version_path).await.is_err() {
|
||||
log::info!("downloading zed-remote-server {os} {arch}");
|
||||
download_remote_server_binary(&version_path, release, client, cx).await?;
|
||||
|
||||
@@ -1178,7 +1178,7 @@ impl Room {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.joined_projects.retain(|project| {
|
||||
if let Some(project) = project.upgrade() {
|
||||
!project.read(cx).is_disconnected()
|
||||
!project.read(cx).is_disconnected(cx)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -58,27 +58,32 @@ struct Args {
|
||||
dev_server_token: Option<String>,
|
||||
}
|
||||
|
||||
fn parse_path_with_position(argument_str: &str) -> Result<String, std::io::Error> {
|
||||
let path = PathWithPosition::parse_str(argument_str);
|
||||
let curdir = env::current_dir()?;
|
||||
|
||||
let canonicalized = path.map_path(|path| match fs::canonicalize(&path) {
|
||||
Ok(path) => Ok(path),
|
||||
Err(e) => {
|
||||
if let Some(mut parent) = path.parent() {
|
||||
if parent == Path::new("") {
|
||||
parent = &curdir
|
||||
fn parse_path_with_position(argument_str: &str) -> anyhow::Result<String> {
|
||||
let canonicalized = match Path::new(argument_str).canonicalize() {
|
||||
Ok(existing_path) => PathWithPosition::from_path(existing_path),
|
||||
Err(_) => {
|
||||
let path = PathWithPosition::parse_str(argument_str);
|
||||
let curdir = env::current_dir().context("reteiving current directory")?;
|
||||
path.map_path(|path| match fs::canonicalize(&path) {
|
||||
Ok(path) => Ok(path),
|
||||
Err(e) => {
|
||||
if let Some(mut parent) = path.parent() {
|
||||
if parent == Path::new("") {
|
||||
parent = &curdir
|
||||
}
|
||||
match fs::canonicalize(parent) {
|
||||
Ok(parent) => Ok(parent.join(path.file_name().unwrap())),
|
||||
Err(_) => Err(e),
|
||||
}
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
match fs::canonicalize(parent) {
|
||||
Ok(parent) => Ok(parent.join(path.file_name().unwrap())),
|
||||
Err(_) => Err(e),
|
||||
}
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
})
|
||||
}
|
||||
})?;
|
||||
Ok(canonicalized.to_string(|path| path.display().to_string()))
|
||||
.with_context(|| format!("parsing as path with position {argument_str}"))?,
|
||||
};
|
||||
Ok(canonicalized.to_string(|path| path.to_string_lossy().to_string()))
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
|
||||
@@ -34,8 +34,8 @@ postage.workspace = true
|
||||
rand.workspace = true
|
||||
release_channel.workspace = true
|
||||
rpc = { workspace = true, features = ["gpui"] }
|
||||
rustls.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -141,6 +141,7 @@ impl Settings for ProxySettings {
|
||||
Ok(Self {
|
||||
proxy: sources
|
||||
.user
|
||||
.or(sources.server)
|
||||
.and_then(|value| value.proxy.clone())
|
||||
.or(sources.default.proxy.clone()),
|
||||
})
|
||||
@@ -472,15 +473,21 @@ impl settings::Settings for TelemetrySettings {
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
Ok(Self {
|
||||
diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
|
||||
sources
|
||||
.default
|
||||
.diagnostics
|
||||
.ok_or_else(Self::missing_default)?,
|
||||
),
|
||||
diagnostics: sources
|
||||
.user
|
||||
.as_ref()
|
||||
.or(sources.server.as_ref())
|
||||
.and_then(|v| v.diagnostics)
|
||||
.unwrap_or(
|
||||
sources
|
||||
.default
|
||||
.diagnostics
|
||||
.ok_or_else(Self::missing_default)?,
|
||||
),
|
||||
metrics: sources
|
||||
.user
|
||||
.as_ref()
|
||||
.or(sources.server.as_ref())
|
||||
.and_then(|v| v.metrics)
|
||||
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
|
||||
})
|
||||
@@ -1023,7 +1030,7 @@ impl Client {
|
||||
&self,
|
||||
http: Arc<HttpClientWithUrl>,
|
||||
release_channel: Option<ReleaseChannel>,
|
||||
) -> impl Future<Output = Result<Url>> {
|
||||
) -> impl Future<Output = Result<url::Url>> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
let url_override = self.rpc_url.read().clone();
|
||||
|
||||
@@ -1117,7 +1124,7 @@ impl Client {
|
||||
// for us from the RPC URL.
|
||||
//
|
||||
// Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
|
||||
let mut request = rpc_url.into_client_request()?;
|
||||
let mut request = IntoClientRequest::into_client_request(rpc_url.as_str())?;
|
||||
|
||||
// We then modify the request to add our desired headers.
|
||||
let request_headers = request.headers_mut();
|
||||
@@ -1156,6 +1163,7 @@ impl Client {
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let (stream, _) =
|
||||
async_tungstenite::async_tls::client_async_tls_with_connector(
|
||||
request,
|
||||
|
||||
@@ -32,12 +32,12 @@ clickhouse.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
isahc_http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
live_kit_server.workspace = true
|
||||
log.workspace = true
|
||||
@@ -48,6 +48,7 @@ prometheus = "0.13"
|
||||
prost.workspace = true
|
||||
rand.workspace = true
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
reqwest_client.workspace = true
|
||||
rpc.workspace = true
|
||||
rustc-demangle.workspace = true
|
||||
scrypt = "0.11"
|
||||
@@ -66,7 +67,7 @@ telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
toml.workspace = true
|
||||
tower = "0.4"
|
||||
tower-http = { workspace = true, features = ["trace"] }
|
||||
|
||||
@@ -199,6 +199,12 @@ spec:
|
||||
secretKeyRef:
|
||||
name: slack
|
||||
key: panics_webhook
|
||||
- name: STRIPE_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: stripe
|
||||
key: api_key
|
||||
optional: true
|
||||
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
|
||||
value: "1000"
|
||||
- name: SUPERMAVEN_ADMIN_API_KEY
|
||||
|
||||
@@ -422,6 +422,15 @@ CREATE TABLE dev_server_projects (
|
||||
paths TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_customers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
create table if not exists billing_preferences (
|
||||
id serial primary key,
|
||||
created_at timestamp without time zone not null default now(),
|
||||
user_id integer not null references users(id) on delete cascade,
|
||||
max_monthly_llm_usage_spending_in_cents integer not null
|
||||
);
|
||||
|
||||
create unique index "uix_billing_preferences_on_user_id" on billing_preferences (user_id);
|
||||
@@ -0,0 +1,12 @@
|
||||
create table billing_events (
|
||||
id serial primary key,
|
||||
idempotency_key uuid not null default gen_random_uuid(),
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
input_tokens bigint not null default 0,
|
||||
input_cache_creation_tokens bigint not null default 0,
|
||||
input_cache_read_tokens bigint not null default 0,
|
||||
output_tokens bigint not null default 0
|
||||
);
|
||||
|
||||
create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);
|
||||
@@ -1,7 +1,3 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use axum::{
|
||||
extract::{self, Query},
|
||||
@@ -9,32 +5,43 @@ use axum::{
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use collections::HashSet;
|
||||
use reqwest::StatusCode;
|
||||
use sea_orm::ActiveValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use stripe::{
|
||||
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
|
||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
|
||||
Subscription, SubscriptionId, SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
|
||||
use crate::db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::{
|
||||
db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
},
|
||||
stripe_billing::StripeBilling,
|
||||
};
|
||||
use crate::{
|
||||
db::{billing_subscription::StripeSubscriptionStatus, UserId},
|
||||
llm::db::LlmDatabase,
|
||||
};
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS;
|
||||
use crate::rpc::ResultExt as _;
|
||||
use crate::{AppState, Error, Result};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route(
|
||||
"/billing/preferences",
|
||||
get(get_billing_preferences).put(update_billing_preferences),
|
||||
)
|
||||
.route(
|
||||
"/billing/subscriptions",
|
||||
get(list_billing_subscriptions).post(create_billing_subscription),
|
||||
@@ -45,6 +52,85 @@ pub fn router() -> Router {
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetBillingPreferencesParams {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn get_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetBillingPreferencesParams>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(params.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let preferences = app.db.get_billing_preferences(user.id).await?;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents
|
||||
}),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpdateBillingPreferencesBody {
|
||||
github_user_id: i32,
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn update_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
|
||||
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let billing_preferences =
|
||||
if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
|
||||
app.db
|
||||
.update_billing_preferences(
|
||||
user.id,
|
||||
&UpdateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
body.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
},
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
app.db
|
||||
.create_billing_preferences(
|
||||
user.id,
|
||||
&crate::db::CreateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents: body
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
rpc_server.refresh_llm_tokens_for_user(user.id).await;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ListBillingSubscriptionsParams {
|
||||
github_user_id: i32,
|
||||
@@ -117,12 +203,22 @@ async fn create_billing_subscription(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some((stripe_client, stripe_price_id)) = app
|
||||
.stripe_client
|
||||
.clone()
|
||||
.zip(app.config.stripe_llm_usage_price_id.clone())
|
||||
else {
|
||||
log::error!("failed to retrieve Stripe client or price ID");
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::error!("failed to retrieve LLM database");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
@@ -146,26 +242,14 @@ async fn create_billing_subscription(
|
||||
customer.id
|
||||
};
|
||||
|
||||
let checkout_session = {
|
||||
let mut params = CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(user.github_login.as_str());
|
||||
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
|
||||
price: Some(stripe_price_id.to_string()),
|
||||
quantity: Some(0),
|
||||
..Default::default()
|
||||
}]);
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
params.success_url = Some(&success_url);
|
||||
|
||||
CheckoutSession::create(&stripe_client, params).await?
|
||||
};
|
||||
|
||||
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
let checkout_session_url = stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?;
|
||||
Ok(Json(CreateBillingSubscriptionResponse {
|
||||
checkout_session_url: checkout_session
|
||||
.url
|
||||
.ok_or_else(|| anyhow!("no checkout session URL"))?,
|
||||
checkout_session_url,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -320,7 +404,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
|
||||
|
||||
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||
/// database with the data in Stripe.
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
@@ -331,7 +415,9 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
poll_stripe_events(&app, &stripe_client).await.log_err();
|
||||
poll_stripe_events(&app, &rpc_server, &stripe_client)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
executor.sleep(POLL_EVENTS_INTERVAL).await;
|
||||
}
|
||||
@@ -341,6 +427,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
||||
|
||||
async fn poll_stripe_events(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
@@ -457,7 +544,7 @@ async fn poll_stripe_events(
|
||||
| EventType::CustomerSubscriptionPaused
|
||||
| EventType::CustomerSubscriptionResumed
|
||||
| EventType::CustomerSubscriptionDeleted => {
|
||||
handle_customer_subscription_event(app, stripe_client, event).await
|
||||
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
|
||||
}
|
||||
_ => Ok(()),
|
||||
};
|
||||
@@ -525,6 +612,7 @@ async fn handle_customer_event(
|
||||
|
||||
async fn handle_customer_subscription_event(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -570,6 +658,12 @@ async fn handle_customer_subscription_event(
|
||||
.await?;
|
||||
}
|
||||
|
||||
// When the user's subscription changes, we want to refresh their LLM tokens
|
||||
// to either grant/revoke access.
|
||||
rpc_server
|
||||
.refresh_llm_tokens_for_user(billing_customer.user_id)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -635,15 +729,15 @@ async fn find_or_create_billing_customer(
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
};
|
||||
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
|
||||
log::warn!("failed to retrieve Stripe LLM usage price ID");
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -652,15 +746,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_with_stripe(
|
||||
&app,
|
||||
&llm_db,
|
||||
&stripe_client,
|
||||
stripe_llm_usage_price_id.clone(),
|
||||
)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
sync_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.trace_err();
|
||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
@@ -669,60 +757,44 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
|
||||
|
||||
async fn sync_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &LlmDatabase,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_llm_usage_price_id: Arc<str>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let subscriptions = app.db.get_active_billing_subscriptions().await?;
|
||||
let events = llm_db.get_billing_events().await?;
|
||||
let user_ids = events
|
||||
.iter()
|
||||
.map(|(event, _)| event.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
|
||||
|
||||
for (customer, subscription) in subscriptions {
|
||||
update_stripe_subscription(
|
||||
llm_db,
|
||||
stripe_client,
|
||||
&stripe_llm_usage_price_id,
|
||||
customer,
|
||||
subscription,
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
for (event, model) in events {
|
||||
let Some((stripe_db_customer, stripe_db_subscription)) =
|
||||
stripe_subscriptions.get(&event.user_id)
|
||||
else {
|
||||
tracing::warn!(
|
||||
user_id = event.user_id.0,
|
||||
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
|
||||
.stripe_subscription_id
|
||||
.parse()
|
||||
.context("failed to parse stripe subscription id from db")?;
|
||||
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
|
||||
.stripe_customer_id
|
||||
.parse()
|
||||
.context("failed to parse stripe customer id from db")?;
|
||||
|
||||
let stripe_model = stripe_billing.register_model(&model).await?;
|
||||
stripe_billing
|
||||
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.await?;
|
||||
llm_db.consume_billing_event(event.id).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_stripe_subscription(
|
||||
llm_db: &LlmDatabase,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_llm_usage_price_id: &Arc<str>,
|
||||
customer: billing_customer::Model,
|
||||
subscription: billing_subscription::Model,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let monthly_spending = llm_db
|
||||
.get_user_spending_for_month(customer.user_id, Utc::now())
|
||||
.await?;
|
||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
||||
.context("failed to parse subscription ID")?;
|
||||
|
||||
let monthly_spending_over_free_tier =
|
||||
monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS);
|
||||
|
||||
let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil();
|
||||
Subscription::update(
|
||||
stripe_client,
|
||||
&subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
||||
// TODO: Do we need to send up the `id` if a subscription item
|
||||
// with this price already exists, or will Stripe take care of
|
||||
// it?
|
||||
id: None,
|
||||
price: Some(stripe_llm_usage_price_id.to_string()),
|
||||
quantity: Some(new_quantity as u64),
|
||||
..Default::default()
|
||||
}]),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -670,7 +670,6 @@ pub struct EditorEventRow {
|
||||
time: i64,
|
||||
copilot_enabled: bool,
|
||||
copilot_enabled_for_language: bool,
|
||||
historical_event: bool,
|
||||
architecture: String,
|
||||
is_staff: Option<bool>,
|
||||
major: Option<i32>,
|
||||
@@ -718,7 +717,6 @@ impl EditorEventRow {
|
||||
country_code: country_code.unwrap_or("XX".to_string()),
|
||||
region_code: "".to_string(),
|
||||
city: "".to_string(),
|
||||
historical_event: false,
|
||||
is_via_ssh: event.is_via_ssh,
|
||||
}
|
||||
}
|
||||
|
||||
78
crates/collab/src/cents.rs
Normal file
78
crates/collab/src/cents.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
/// A number of cents.
|
||||
#[derive(
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Hash,
|
||||
Clone,
|
||||
Copy,
|
||||
derive_more::Add,
|
||||
derive_more::AddAssign,
|
||||
)]
|
||||
pub struct Cents(pub u32);
|
||||
|
||||
impl Cents {
|
||||
pub const ZERO: Self = Self(0);
|
||||
|
||||
pub const fn new(cents: u32) -> Self {
|
||||
Self(cents)
|
||||
}
|
||||
|
||||
pub const fn from_dollars(dollars: u32) -> Self {
|
||||
Self(dollars * 100)
|
||||
}
|
||||
|
||||
pub fn saturating_sub(self, other: Cents) -> Self {
|
||||
Self(self.0.saturating_sub(other.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cents_new() {
|
||||
assert_eq!(Cents::new(50), Cents(50));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_from_dollars() {
|
||||
assert_eq!(Cents::from_dollars(1), Cents(100));
|
||||
assert_eq!(Cents::from_dollars(5), Cents(500));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_zero() {
|
||||
assert_eq!(Cents::ZERO, Cents(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_add() {
|
||||
assert_eq!(Cents(50) + Cents(30), Cents(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_add_assign() {
|
||||
let mut cents = Cents(50);
|
||||
cents += Cents(30);
|
||||
assert_eq!(cents, Cents(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_saturating_sub() {
|
||||
assert_eq!(Cents(50).saturating_sub(Cents(30)), Cents(20));
|
||||
assert_eq!(Cents(30).saturating_sub(Cents(50)), Cents(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_ordering() {
|
||||
assert!(Cents(50) > Cents(30));
|
||||
assert!(Cents(30) < Cents(50));
|
||||
assert_eq!(Cents(50), Cents(50));
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,9 @@ pub use tests::TestDb;
|
||||
|
||||
pub use ids::*;
|
||||
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
|
||||
pub use queries::billing_preferences::{
|
||||
CreateBillingPreferencesParams, UpdateBillingPreferencesParams,
|
||||
};
|
||||
pub use queries::billing_subscriptions::{
|
||||
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
|
||||
};
|
||||
|
||||
@@ -72,6 +72,7 @@ macro_rules! id_type {
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(BillingCustomerId);
|
||||
id_type!(BillingSubscriptionId);
|
||||
id_type!(BillingPreferencesId);
|
||||
id_type!(BufferId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(ChannelChatParticipantId);
|
||||
|
||||
@@ -2,6 +2,7 @@ use super::*;
|
||||
|
||||
pub mod access_tokens;
|
||||
pub mod billing_customers;
|
||||
pub mod billing_preferences;
|
||||
pub mod billing_subscriptions;
|
||||
pub mod buffers;
|
||||
pub mod channels;
|
||||
|
||||
75
crates/collab/src/db/queries/billing_preferences.rs
Normal file
75
crates/collab/src/db/queries/billing_preferences.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: ActiveValue<i32>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Returns the billing preferences for the given user, if they exist.
|
||||
pub async fn get_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_preference::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_preference::Entity::find()
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Creates new billing preferences for the given user.
|
||||
pub async fn create_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
params: &CreateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
params.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(preferences)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the billing preferences for the given user.
|
||||
pub async fn update_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
params: &UpdateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::update_many()
|
||||
.set(billing_preference::ActiveModel {
|
||||
max_monthly_llm_usage_spending_in_cents: params
|
||||
.max_monthly_llm_usage_spending_in_cents
|
||||
.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(preferences
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("billing preferences not found"))?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -114,23 +114,31 @@ impl Database {
|
||||
|
||||
pub async fn get_active_billing_subscriptions(
|
||||
&self,
|
||||
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let mut result = Vec::new();
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
result.push((customer, subscription));
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -838,6 +838,7 @@ impl Database {
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id as u64,
|
||||
name: language_server.name,
|
||||
worktree_id: None,
|
||||
})
|
||||
.collect(),
|
||||
dev_server_project_id: project.dev_server_project_id,
|
||||
|
||||
@@ -718,6 +718,7 @@ impl Database {
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id as u64,
|
||||
name: language_server.name,
|
||||
worktree_id: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod access_token;
|
||||
pub mod billing_customer;
|
||||
pub mod billing_preference;
|
||||
pub mod billing_subscription;
|
||||
pub mod buffer;
|
||||
pub mod buffer_operation;
|
||||
|
||||
30
crates/collab/src/db/tables/billing_preference.rs
Normal file
30
crates/collab/src/db/tables/billing_preference.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use crate::db::{BillingPreferencesId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_preferences")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingPreferencesId,
|
||||
pub created_at: DateTime,
|
||||
pub user_id: UserId,
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
mod cents;
|
||||
pub mod clickhouse;
|
||||
pub mod db;
|
||||
pub mod env;
|
||||
@@ -9,6 +10,7 @@ pub mod migrations;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
pub mod user_backfiller;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -20,13 +22,17 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
pub use cents::*;
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
pub use rate_limiter::*;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
@@ -174,7 +180,6 @@ pub struct Config {
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub stripe_llm_usage_price_id: Option<Arc<str>>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
@@ -194,7 +199,7 @@ impl Config {
|
||||
}
|
||||
|
||||
pub fn is_llm_billing_enabled(&self) -> bool {
|
||||
self.stripe_llm_usage_price_id.is_some()
|
||||
self.stripe_api_key.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -235,7 +240,6 @@ impl Config {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_llm_usage_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
}
|
||||
@@ -268,9 +272,11 @@ impl ServiceMode {
|
||||
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub clickhouse_client: Option<::clickhouse::Client>,
|
||||
@@ -284,6 +290,20 @@ impl AppState {
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||
.llm_database_url
|
||||
.clone()
|
||||
.zip(config.llm_database_max_connections)
|
||||
{
|
||||
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
|
||||
llm_db_options.max_connections(llm_database_max_connections);
|
||||
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
|
||||
llm_db.initialize().await?;
|
||||
Some(Arc::new(llm_db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let live_kit_client = if let Some(((server, key), secret)) = config
|
||||
.live_kit_server
|
||||
.as_ref()
|
||||
@@ -300,11 +320,16 @@ impl AppState {
|
||||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
live_kit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
clickhouse_client: config
|
||||
@@ -317,12 +342,11 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
let api_key = config
|
||||
.stripe_api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
|
||||
|
||||
Ok(stripe::Client::new(api_key))
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ mod telemetry;
|
||||
mod token;
|
||||
|
||||
use crate::{
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
|
||||
Config, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
@@ -20,13 +20,14 @@ use axum::{
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use db::TokenUsage;
|
||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use isahc_http_client::IsahcHttpClient;
|
||||
use rpc::ListModelsResponse;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
@@ -43,7 +44,7 @@ pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub db: Arc<LlmDatabase>,
|
||||
pub http_client: IsahcHttpClient,
|
||||
pub http_client: ReqwestClient,
|
||||
pub clickhouse_client: Option<clickhouse::Client>,
|
||||
active_user_count_by_model:
|
||||
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
|
||||
@@ -69,11 +70,8 @@ impl LlmState {
|
||||
let db = Arc::new(db);
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client = IsahcHttpClient::builder()
|
||||
.default_header("User-Agent", user_agent)
|
||||
.build()
|
||||
.map(IsahcHttpClient::from)
|
||||
.context("failed to construct http client")?;
|
||||
let http_client =
|
||||
ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?;
|
||||
|
||||
let this = Self {
|
||||
executor,
|
||||
@@ -418,10 +416,7 @@ async fn perform_completion(
|
||||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
tokens: TokenUsage::default(),
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
@@ -438,13 +433,18 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// The maximum monthly spending an individual user can reach before they have to pay.
|
||||
pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100;
|
||||
/// The maximum monthly spending an individual user can reach on the free tier
|
||||
/// before they have to pay.
|
||||
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);
|
||||
|
||||
/// The default value to use for maximum spend per month if the user did not
|
||||
/// explicitly set a maximum spend.
|
||||
///
|
||||
/// Used to prevent surprise bills.
|
||||
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
|
||||
|
||||
/// The maximum lifetime spending an individual user can reach before being cut off.
|
||||
///
|
||||
/// Represented in cents.
|
||||
const LIFETIME_SPENDING_LIMIT_IN_CENTS: usize = 1_000 * 100;
|
||||
const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);
|
||||
|
||||
async fn check_usage_limit(
|
||||
state: &Arc<LlmState>,
|
||||
@@ -464,18 +464,31 @@ async fn check_usage_limit(
|
||||
.await?;
|
||||
|
||||
if state.config.is_llm_billing_enabled() {
|
||||
if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS {
|
||||
if !claims.has_llm_subscription.unwrap_or(false) {
|
||||
if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT {
|
||||
if !claims.has_llm_subscription {
|
||||
return Err(Error::http(
|
||||
StatusCode::PAYMENT_REQUIRED,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
|
||||
return Err(Error::Http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Remove this once we've rolled out monthly spending limits.
|
||||
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
|
||||
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT {
|
||||
return Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached.".to_string(),
|
||||
@@ -593,10 +606,7 @@ struct TokenCountingStream<S> {
|
||||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
tokens: TokenUsage,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
@@ -610,10 +620,10 @@ where
|
||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||
chunk.bytes.push(b'\n');
|
||||
self.input_tokens += chunk.input_tokens;
|
||||
self.output_tokens += chunk.output_tokens;
|
||||
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
|
||||
self.cache_read_input_tokens += chunk.cache_read_input_tokens;
|
||||
self.tokens.input += chunk.input_tokens;
|
||||
self.tokens.output += chunk.output_tokens;
|
||||
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
@@ -626,13 +636,11 @@ where
|
||||
impl<S> Drop for TokenCountingStream<S> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let is_llm_billing_enabled = state.config.is_llm_billing_enabled();
|
||||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let input_token_count = self.input_tokens;
|
||||
let output_token_count = self.output_tokens;
|
||||
let cache_creation_input_token_count = self.cache_creation_input_tokens;
|
||||
let cache_read_input_token_count = self.cache_read_input_tokens;
|
||||
let tokens = self.tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
@@ -641,10 +649,16 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
input_token_count,
|
||||
cache_creation_input_token_count,
|
||||
cache_read_input_token_count,
|
||||
output_token_count,
|
||||
tokens,
|
||||
// We're passing `false` here if LLM billing is not enabled
|
||||
// so that we don't write any records to the
|
||||
// `billing_events` table until we're ready to bill users.
|
||||
if is_llm_billing_enabled {
|
||||
claims.has_llm_subscription
|
||||
} else {
|
||||
false
|
||||
},
|
||||
Cents(claims.max_monthly_spend_in_cents),
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
@@ -674,24 +688,25 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||
},
|
||||
model,
|
||||
provider: provider.to_string(),
|
||||
input_token_count: input_token_count as u64,
|
||||
cache_creation_input_token_count: cache_creation_input_token_count
|
||||
as u64,
|
||||
cache_read_input_token_count: cache_read_input_token_count as u64,
|
||||
output_token_count: output_token_count as u64,
|
||||
input_token_count: tokens.input as u64,
|
||||
cache_creation_input_token_count: tokens.input_cache_creation as u64,
|
||||
cache_read_input_token_count: tokens.input_cache_read as u64,
|
||||
output_token_count: tokens.output as u64,
|
||||
requests_this_minute: usage.requests_this_minute as u64,
|
||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||
tokens_this_day: usage.tokens_this_day as u64,
|
||||
input_tokens_this_month: usage.input_tokens_this_month as u64,
|
||||
input_tokens_this_month: usage.tokens_this_month.input as u64,
|
||||
cache_creation_input_tokens_this_month: usage
|
||||
.cache_creation_input_tokens_this_month
|
||||
.tokens_this_month
|
||||
.input_cache_creation
|
||||
as u64,
|
||||
cache_read_input_tokens_this_month: usage
|
||||
.cache_read_input_tokens_this_month
|
||||
.tokens_this_month
|
||||
.input_cache_read
|
||||
as u64,
|
||||
output_tokens_this_month: usage.output_tokens_this_month as u64,
|
||||
spending_this_month: usage.spending_this_month as u64,
|
||||
lifetime_spending: usage.lifetime_spending as u64,
|
||||
output_tokens_this_month: usage.tokens_this_month.output as u64,
|
||||
spending_this_month: usage.spending_this_month.0 as u64,
|
||||
lifetime_spending: usage.lifetime_spending.0 as u64,
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -20,7 +20,7 @@ use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
pub use queries::usages::ActiveUserCount;
|
||||
pub use queries::usages::{ActiveUserCount, TokenUsage};
|
||||
use sea_orm::prelude::*;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::{
|
||||
|
||||
@@ -3,8 +3,9 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::id_type;
|
||||
|
||||
id_type!(BillingEventId);
|
||||
id_type!(ModelId);
|
||||
id_type!(ProviderId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
id_type!(UsageId);
|
||||
id_type!(UsageMeasureId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod revoked_access_tokens;
|
||||
pub mod usages;
|
||||
|
||||
31
crates/collab/src/llm/db/queries/billing_events.rs
Normal file
31
crates/collab/src/llm/db/queries/billing_events.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use super::*;
|
||||
use crate::Result;
|
||||
use anyhow::Context as _;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let events_with_models = billing_event::Entity::find()
|
||||
.find_also_related(model::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
events_with_models
|
||||
.into_iter()
|
||||
.map(|(event, model)| {
|
||||
let model =
|
||||
model.context("could not find model associated with billing event")?;
|
||||
Ok((event, model))
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::db::UserId;
|
||||
use crate::llm::Cents;
|
||||
use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use chrono::{Datelike, Duration};
|
||||
use futures::StreamExt as _;
|
||||
use rpc::LanguageModelProvider;
|
||||
@@ -8,17 +9,28 @@ use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default)]
|
||||
pub struct TokenUsage {
|
||||
pub input: usize,
|
||||
pub input_cache_creation: usize,
|
||||
pub input_cache_read: usize,
|
||||
pub output: usize,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total(&self) -> usize {
|
||||
self.input + self.input_cache_creation + self.input_cache_read + self.output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub tokens_this_day: usize,
|
||||
pub input_tokens_this_month: usize,
|
||||
pub cache_creation_input_tokens_this_month: usize,
|
||||
pub cache_read_input_tokens_this_month: usize,
|
||||
pub output_tokens_this_month: usize,
|
||||
pub spending_this_month: usize,
|
||||
pub lifetime_spending: usize,
|
||||
pub tokens_this_month: TokenUsage,
|
||||
pub spending_this_month: Cents,
|
||||
pub lifetime_spending: Cents,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
@@ -144,7 +156,7 @@ impl LlmDatabase {
|
||||
&self,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<usize> {
|
||||
) -> Result<Cents> {
|
||||
self.transaction(|tx| async move {
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
@@ -158,7 +170,7 @@ impl LlmDatabase {
|
||||
)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
let mut monthly_spending_in_cents = 0;
|
||||
let mut monthly_spending = Cents::ZERO;
|
||||
|
||||
while let Some(usage) = monthly_usages.next().await {
|
||||
let usage = usage?;
|
||||
@@ -166,7 +178,7 @@ impl LlmDatabase {
|
||||
continue;
|
||||
};
|
||||
|
||||
monthly_spending_in_cents += calculate_spending(
|
||||
monthly_spending += calculate_spending(
|
||||
model,
|
||||
usage.input_tokens as usize,
|
||||
usage.cache_creation_input_tokens as usize,
|
||||
@@ -175,7 +187,7 @@ impl LlmDatabase {
|
||||
);
|
||||
}
|
||||
|
||||
Ok(monthly_spending_in_cents)
|
||||
Ok(monthly_spending)
|
||||
})
|
||||
.await
|
||||
}
|
||||
@@ -238,7 +250,7 @@ impl LlmDatabase {
|
||||
monthly_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
Cents::ZERO
|
||||
};
|
||||
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
|
||||
calculate_spending(
|
||||
@@ -249,25 +261,27 @@ impl LlmDatabase {
|
||||
lifetime_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
0
|
||||
Cents::ZERO
|
||||
};
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
input_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
cache_creation_input_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
cache_read_input_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output_tokens_this_month: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
input_cache_creation: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
input_cache_read: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
},
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
@@ -282,10 +296,9 @@ impl LlmDatabase {
|
||||
is_staff: bool,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
input_token_count: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
output_token_count: usize,
|
||||
tokens: TokenUsage,
|
||||
has_llm_subscription: bool,
|
||||
max_monthly_spend: Cents,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
@@ -312,10 +325,6 @@ impl LlmDatabase {
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let total_token_count = input_token_count
|
||||
+ cache_read_input_tokens
|
||||
+ cache_creation_input_tokens
|
||||
+ output_token_count;
|
||||
let tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
@@ -324,7 +333,7 @@ impl LlmDatabase {
|
||||
&usages,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
now,
|
||||
total_token_count,
|
||||
tokens.total(),
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
@@ -336,7 +345,7 @@ impl LlmDatabase {
|
||||
&usages,
|
||||
UsageMeasure::TokensPerDay,
|
||||
now,
|
||||
total_token_count,
|
||||
tokens.total(),
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
@@ -360,18 +369,14 @@ impl LlmDatabase {
|
||||
Some(usage) => {
|
||||
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(
|
||||
usage.input_tokens + input_token_count as i64,
|
||||
),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(
|
||||
usage.output_tokens + output_token_count as i64,
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
@@ -383,12 +388,12 @@ impl LlmDatabase {
|
||||
model_id: ActiveValue::set(model.id),
|
||||
month: ActiveValue::set(month),
|
||||
year: ActiveValue::set(year),
|
||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
cache_creation_input_tokens as i64,
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
|
||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
@@ -404,6 +409,27 @@ impl LlmDatabase {
|
||||
monthly_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
if !is_staff
|
||||
&& spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
|
||||
&& has_llm_subscription
|
||||
&& spending_this_month <= max_monthly_spend
|
||||
{
|
||||
billing_event::ActiveModel {
|
||||
id: ActiveValue::not_set(),
|
||||
idempotency_key: ActiveValue::not_set(),
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_cache_creation_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Update lifetime usage
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
@@ -418,18 +444,14 @@ impl LlmDatabase {
|
||||
Some(usage) => {
|
||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(
|
||||
usage.input_tokens + input_token_count as i64,
|
||||
),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(
|
||||
usage.output_tokens + output_token_count as i64,
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
@@ -439,12 +461,12 @@ impl LlmDatabase {
|
||||
lifetime_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
cache_creation_input_tokens as i64,
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
|
||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
@@ -464,11 +486,12 @@ impl LlmDatabase {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
input_tokens_this_month: monthly_usage.input_tokens as usize,
|
||||
cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
|
||||
as usize,
|
||||
cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
|
||||
output_tokens_this_month: monthly_usage.output_tokens as usize,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage.input_tokens as usize,
|
||||
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
|
||||
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
|
||||
output: monthly_usage.output_tokens as usize,
|
||||
},
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
@@ -637,7 +660,7 @@ fn calculate_spending(
|
||||
cache_creation_input_tokens_this_month: usize,
|
||||
cache_read_input_tokens_this_month: usize,
|
||||
output_tokens_this_month: usize,
|
||||
) -> usize {
|
||||
) -> Cents {
|
||||
let input_token_cost =
|
||||
input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000;
|
||||
let cache_creation_input_token_cost = cache_creation_input_tokens_this_month
|
||||
@@ -648,10 +671,11 @@ fn calculate_spending(
|
||||
/ 1_000_000;
|
||||
let output_token_cost =
|
||||
output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000;
|
||||
input_token_cost
|
||||
let spending = input_token_cost
|
||||
+ cache_creation_input_token_cost
|
||||
+ cache_read_input_token_cost
|
||||
+ output_token_cost
|
||||
+ output_token_cost;
|
||||
Cents::new(spending as u32)
|
||||
}
|
||||
|
||||
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod billing_event;
|
||||
pub mod lifetime_usage;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
|
||||
37
crates/collab/src/llm/db/tables/billing_event.rs
Normal file
37
crates/collab/src/llm/db/tables/billing_event.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{BillingEventId, ModelId},
|
||||
};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingEventId,
|
||||
pub idempotency_key: Uuid,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub input_cache_creation_tokens: i64,
|
||||
pub input_cache_read_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -29,6 +29,8 @@ pub enum Relation {
|
||||
Provider,
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
#[sea_orm(has_many = "super::billing_event::Entity")]
|
||||
BillingEvents,
|
||||
}
|
||||
|
||||
impl Related<super::provider::Entity> for Entity {
|
||||
@@ -43,4 +45,10 @@ impl Related<super::usage::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_event::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingEvents.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod billing_tests;
|
||||
mod provider_tests;
|
||||
mod usage_tests;
|
||||
|
||||
|
||||
158
crates/collab/src/llm/db/tests/billing_tests.rs
Normal file
158
crates/collab/src/llm/db/tests/billing_tests.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::{
|
||||
db::{
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
LlmDatabase, TokenUsage,
|
||||
},
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
},
|
||||
test_llm_db, Cents,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(
|
||||
test_billing_limit_exceeded,
|
||||
test_billing_limit_exceeded_postgres
|
||||
);
|
||||
|
||||
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "fake-claude-limerick";
|
||||
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
|
||||
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
|
||||
|
||||
// Initialize the database and insert the model
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
|
||||
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Set a fixed datetime for consistent testing
|
||||
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let max_monthly_spend = Cents::from_dollars(11);
|
||||
|
||||
// Record usage that brings us close to the limit but doesn't exceed it
|
||||
// Let's say we use $10.50 worth of tokens
|
||||
let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens
|
||||
let usage = TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// Verify that before we record any usage, there are 0 billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 0);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the recorded usage and spending
|
||||
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
|
||||
// Verify that we exceeded the free tier usage
|
||||
assert!(
|
||||
recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
"Expected spending to exceed free tier limit"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
recorded_usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: tokens_to_use,
|
||||
tokens_this_day: tokens_to_use,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::new(1050),
|
||||
lifetime_spending: Cents::new(1050),
|
||||
}
|
||||
);
|
||||
|
||||
// Verify that there is one `billing_event` record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
let (billing_event, _model) = &billing_events[0];
|
||||
assert_eq!(billing_event.user_id, user_id);
|
||||
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
|
||||
assert_eq!(billing_event.input_cache_creation_tokens, 0);
|
||||
assert_eq!(billing_event.input_cache_read_tokens, 0);
|
||||
assert_eq!(billing_event.output_tokens, 0);
|
||||
|
||||
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $10.50 to $11.50, which is over the $11 monthly maximum limit
|
||||
let usage_exceeding = TokenUsage {
|
||||
input: tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// This should still create a billing event as it's the first request that exceeds the limit
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_exceeding,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that there is still one billing record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
updated_usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: tokens_to_use + tokens_to_exceed,
|
||||
tokens_this_day: tokens_to_use + tokens_to_exceed,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: tokens_to_use + tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::new(1150),
|
||||
lifetime_spending: Cents::new(1150),
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -2,9 +2,9 @@ use crate::{
|
||||
db::UserId,
|
||||
llm::db::{
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
LlmDatabase,
|
||||
LlmDatabase, TokenUsage,
|
||||
},
|
||||
test_llm_db,
|
||||
test_llm_db, Cents,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
@@ -36,14 +36,42 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let now = t0;
|
||||
db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let now = t0 + Duration::seconds(10);
|
||||
db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -52,12 +80,14 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 3000,
|
||||
tokens_this_day: 3000,
|
||||
input_tokens_this_month: 3000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -69,19 +99,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 2000,
|
||||
tokens_this_day: 3000,
|
||||
input_tokens_this_month: 3000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -90,12 +136,14 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 5000,
|
||||
tokens_this_day: 6000,
|
||||
input_tokens_this_month: 6000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -108,18 +156,34 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
tokens_this_day: 5000,
|
||||
input_tokens_this_month: 6000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 4000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -128,12 +192,14 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 4000,
|
||||
tokens_this_day: 9000,
|
||||
input_tokens_this_month: 10000,
|
||||
cache_creation_input_tokens_this_month: 0,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 10000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -143,9 +209,23 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
.with_timezone(&Utc);
|
||||
|
||||
// Test cache creation input tokens
|
||||
db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -154,19 +234,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 1500,
|
||||
tokens_this_day: 1500,
|
||||
input_tokens_this_month: 1000,
|
||||
cache_creation_input_tokens_this_month: 500,
|
||||
cache_read_input_tokens_this_month: 0,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
// Test cache read input tokens
|
||||
db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -175,12 +271,14 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 2800,
|
||||
tokens_this_day: 2800,
|
||||
input_tokens_this_month: 2000,
|
||||
cache_creation_input_tokens_this_month: 500,
|
||||
cache_read_input_tokens_this_month: 300,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
use crate::{db::UserId, Config};
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::{
|
||||
db::{billing_preference, UserId},
|
||||
Config,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
@@ -16,22 +20,20 @@ pub struct LlmTokenClaims {
|
||||
pub github_user_login: String,
|
||||
pub is_staff: bool,
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
// This field is temporarily optional so it can be added
|
||||
// in a backwards-compatible way. We can make it required
|
||||
// once all of the LLM tokens have cycled (~1 hour after
|
||||
// this change has been deployed).
|
||||
#[serde(default)]
|
||||
pub has_llm_subscription: Option<bool>,
|
||||
pub has_llm_subscription: bool,
|
||||
pub max_monthly_spend_in_cents: u32,
|
||||
pub plan: rpc::proto::Plan,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn create(
|
||||
user_id: UserId,
|
||||
github_user_login: String,
|
||||
is_staff: bool,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
has_llm_closed_beta_feature_flag: bool,
|
||||
has_llm_subscription: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
@@ -51,7 +53,11 @@ impl LlmTokenClaims {
|
||||
github_user_login,
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
has_llm_subscription: Some(has_llm_subscription),
|
||||
has_llm_subscription,
|
||||
max_monthly_spend_in_cents: billing_preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents as u32
|
||||
}),
|
||||
plan,
|
||||
};
|
||||
|
||||
|
||||
@@ -111,6 +111,13 @@ async fn main() -> Result<()> {
|
||||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
if let Some(stripe_billing) = state.stripe_billing.clone() {
|
||||
let executor = state.executor.clone();
|
||||
executor.spawn_detached(async move {
|
||||
stripe_billing.initialize().await.trace_err();
|
||||
});
|
||||
}
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
RateLimiter::save_periodically(
|
||||
@@ -125,6 +132,8 @@ async fn main() -> Result<()> {
|
||||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||
rpc_server.start().await?;
|
||||
|
||||
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::api::routes(rpc_server.clone()))
|
||||
.merge(collab::rpc::routes(rpc_server.clone()));
|
||||
@@ -133,7 +142,6 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
if mode.is_api() {
|
||||
poll_stripe_events_periodically(state.clone());
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
spawn_user_backfiller(state.clone());
|
||||
|
||||
@@ -155,8 +163,9 @@ async fn main() -> Result<()> {
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
if let Some(llm_db) = llm_db {
|
||||
sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
|
||||
if let Some(mut llm_db) = llm_db {
|
||||
llm_db.initialize().await?;
|
||||
sync_llm_usage_with_stripe_periodically(state.clone());
|
||||
}
|
||||
|
||||
app = app
|
||||
|
||||
@@ -36,8 +36,8 @@ use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
use http_client::HttpClient;
|
||||
use isahc_http_client::IsahcHttpClient;
|
||||
use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use sha2::Digest;
|
||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||
|
||||
@@ -469,9 +469,6 @@ impl Server {
|
||||
.add_request_handler(user_handler(
|
||||
forward_project_request_for_owner::<proto::TaskContextForLocation>,
|
||||
))
|
||||
.add_request_handler(user_handler(
|
||||
forward_project_request_for_owner::<proto::TaskTemplates>,
|
||||
))
|
||||
.add_request_handler(user_handler(
|
||||
forward_read_only_project_request::<proto::GetHover>,
|
||||
))
|
||||
@@ -964,8 +961,8 @@ impl Server {
|
||||
tracing::info!("connection opened");
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
|
||||
Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
|
||||
let http_client = match ReqwestClient::user_agent(&user_agent) {
|
||||
Ok(http_client) => Arc::new(http_client),
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to create HTTP client");
|
||||
return;
|
||||
@@ -1221,6 +1218,15 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
self.peer
|
||||
.send(connection_id, proto::RefreshLlmToken {})
|
||||
.trace_err();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
||||
ServerSnapshot {
|
||||
connection_pool: ConnectionPoolGuard {
|
||||
@@ -4920,10 +4926,14 @@ async fn get_llm_api_token(
|
||||
if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
|
||||
Err(anyhow!("account too young"))?
|
||||
}
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
user.id,
|
||||
user.github_login.clone(),
|
||||
session.is_staff(),
|
||||
billing_preferences,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
session.has_llm_subscription(&db).await?,
|
||||
session.current_plan(&db).await?,
|
||||
|
||||
469
crates/collab/src/stripe_billing.rs
Normal file
469
crates/collab/src/stripe_billing.rs
Normal file
@@ -0,0 +1,469 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{llm, Cents, Result};
|
||||
use anyhow::Context;
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
client: Arc<stripe::Client>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct StripeBillingState {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||
}
|
||||
|
||||
pub struct StripeModel {
|
||||
input_tokens_price: StripeBillingPrice,
|
||||
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||
input_cache_read_tokens_price: StripeBillingPrice,
|
||||
output_tokens_price: StripeBillingPrice,
|
||||
}
|
||||
|
||||
struct StripeBillingPrice {
|
||||
id: stripe::PriceId,
|
||||
meter_event_name: String,
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
state: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
log::info!("StripeBilling: initializing");
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
let (meters, prices) = futures::try_join!(
|
||||
StripeMeter::list(&self.client),
|
||||
stripe::Price::list(
|
||||
&self.client,
|
||||
&stripe::ListPrices {
|
||||
limit: Some(100),
|
||||
..Default::default()
|
||||
}
|
||||
)
|
||||
)?;
|
||||
|
||||
for meter in meters.data {
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter.event_name.clone(), meter);
|
||||
}
|
||||
|
||||
for price in prices.data {
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
state.price_ids_by_meter_id.insert(meter, price.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("StripeBilling: initialized");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||
let input_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_tokens", model.id),
|
||||
&format!("{} (Input Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_creation_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_cache_creation_tokens", model.id),
|
||||
&format!("{} (Input Cache Creation Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_read_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_cache_read_tokens", model.id),
|
||||
&format!("{} (Input Cache Read Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/output_tokens", model.id),
|
||||
&format!("{} (Output Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_output_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
Ok(StripeModel {
|
||||
input_tokens_price,
|
||||
input_cache_creation_tokens_price,
|
||||
input_cache_read_tokens_price,
|
||||
output_tokens_price,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_or_insert_price(
|
||||
&self,
|
||||
meter_event_name: &str,
|
||||
price_description: &str,
|
||||
price_per_million_tokens: Cents,
|
||||
) -> Result<StripeBillingPrice> {
|
||||
// Fast code path when the meter and the price already exist.
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
return Ok(StripeBillingPrice {
|
||||
id: price_id.clone(),
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
meter.clone()
|
||||
} else {
|
||||
let meter = StripeMeter::create(
|
||||
&self.client,
|
||||
StripeCreateMeterParams {
|
||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||
display_name: price_description.to_string(),
|
||||
event_name: meter_event_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter_event_name.to_string(), meter.clone());
|
||||
meter
|
||||
};
|
||||
|
||||
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
price_id.clone()
|
||||
} else {
|
||||
let price = stripe::Price::create(
|
||||
&self.client,
|
||||
stripe::CreatePrice {
|
||||
active: Some(true),
|
||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||
currency: stripe::Currency::USD,
|
||||
currency_options: None,
|
||||
custom_unit_amount: None,
|
||||
expand: &[],
|
||||
lookup_key: None,
|
||||
metadata: None,
|
||||
nickname: None,
|
||||
product: None,
|
||||
product_data: Some(stripe::CreatePriceProductData {
|
||||
id: None,
|
||||
active: Some(true),
|
||||
metadata: None,
|
||||
name: price_description.to_string(),
|
||||
statement_descriptor: None,
|
||||
tax_code: None,
|
||||
unit_label: None,
|
||||
}),
|
||||
recurring: Some(stripe::CreatePriceRecurring {
|
||||
aggregate_usage: None,
|
||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||
interval_count: None,
|
||||
trial_period_days: None,
|
||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||
meter: Some(meter.id.clone()),
|
||||
}),
|
||||
tax_behavior: None,
|
||||
tiers: None,
|
||||
tiers_mode: None,
|
||||
transfer_lookup_key: None,
|
||||
transform_quantity: None,
|
||||
unit_amount: None,
|
||||
unit_amount_decimal: Some(&format!(
|
||||
"{:.12}",
|
||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.price_ids_by_meter_id
|
||||
.insert(meter.id, price.id.clone());
|
||||
price.id
|
||||
};
|
||||
|
||||
Ok(StripeBillingPrice {
|
||||
id: price_id,
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_model(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
model: &StripeModel,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
|
||||
{
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_read_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.output_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !items.is_empty() {
|
||||
items.extend(subscription.items.data.iter().map(|item| {
|
||||
stripe::UpdateSubscriptionItems {
|
||||
id: Some(item.id.to_string()),
|
||||
..Default::default()
|
||||
}
|
||||
}));
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(items),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
model: &StripeModel,
|
||||
event: &llm::db::billing_event::Model,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
|
||||
if event.input_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_creation_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_creation_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_read_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_read_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_read_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.output_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("output_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.output_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.output_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn checkout(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
model: &StripeModel,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(
|
||||
[
|
||||
&model.input_tokens_price.id,
|
||||
&model.input_cache_creation_tokens_price.id,
|
||||
&model.input_cache_read_tokens_price.id,
|
||||
&model.output_tokens_price.id,
|
||||
]
|
||||
.into_iter()
|
||||
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(price_id.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DefaultAggregation {
|
||||
formula: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterParams<'a> {
|
||||
default_aggregation: DefaultAggregation,
|
||||
display_name: String,
|
||||
event_name: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
struct StripeMeter {
|
||||
id: String,
|
||||
event_name: String,
|
||||
}
|
||||
|
||||
impl StripeMeter {
|
||||
pub fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterParams,
|
||||
) -> stripe::Response<Self> {
|
||||
client.post_form("/billing/meters", params)
|
||||
}
|
||||
|
||||
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
limit: Option<u64>,
|
||||
}
|
||||
|
||||
client.get_query("/billing/meters", Params { limit: Some(100) })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StripeMeterEvent {
|
||||
identifier: String,
|
||||
}
|
||||
|
||||
impl StripeMeterEvent {
|
||||
pub async fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterEventParams<'_>,
|
||||
) -> Result<Self, stripe::StripeError> {
|
||||
let identifier = params.identifier;
|
||||
match client.post_form("/billing/meter_events", params).await {
|
||||
Ok(event) => Ok(event),
|
||||
Err(stripe::StripeError::Stripe(error)) => {
|
||||
if error.http_status == 400
|
||||
&& error
|
||||
.message
|
||||
.as_ref()
|
||||
.map_or(false, |message| message.contains(identifier))
|
||||
{
|
||||
Ok(Self {
|
||||
identifier: identifier.to_string(),
|
||||
})
|
||||
} else {
|
||||
Err(stripe::StripeError::Stripe(error))
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterEventParams<'a> {
|
||||
identifier: &'a str,
|
||||
event_name: &'a str,
|
||||
payload: StripeCreateMeterEventPayload<'a>,
|
||||
timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterEventPayload<'a> {
|
||||
value: u64,
|
||||
stripe_customer_id: &'a stripe::CustomerId,
|
||||
}
|
||||
|
||||
fn subscription_contains_price(
|
||||
subscription: &stripe::Subscription,
|
||||
price_id: &stripe::PriceId,
|
||||
) -> bool {
|
||||
subscription.items.data.iter().any(|item| {
|
||||
item.price
|
||||
.as_ref()
|
||||
.map_or(false, |price| price.id == *price_id)
|
||||
})
|
||||
}
|
||||
@@ -50,7 +50,7 @@ async fn test_channel_guests(
|
||||
project_b.read_with(cx_b, |project, _| project.remote_id()),
|
||||
Some(project_id),
|
||||
);
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b
|
||||
.update(cx_b, |project, cx| {
|
||||
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
|
||||
@@ -103,7 +103,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
workspace.active_item_as::<Editor>(cx).unwrap(),
|
||||
)
|
||||
});
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(editor_b.update(cx_b, |e, cx| e.read_only(cx)));
|
||||
assert!(room_b.read_with(cx_b, |room, _| !room.can_use_microphone()));
|
||||
assert!(room_b
|
||||
@@ -127,7 +127,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
cx_a.run_until_parked();
|
||||
|
||||
// project and buffers are now editable
|
||||
assert!(project_b.read_with(cx_b, |project, _| !project.is_read_only()));
|
||||
assert!(project_b.read_with(cx_b, |project, cx| !project.is_read_only(cx)));
|
||||
assert!(editor_b.update(cx_b, |editor, cx| !editor.read_only(cx)));
|
||||
|
||||
// B sees themselves as muted, and can unmute.
|
||||
@@ -153,7 +153,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
cx_a.run_until_parked();
|
||||
|
||||
// project and buffers are no longer editable
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(editor_b.update(cx_b, |editor, cx| editor.read_only(cx)));
|
||||
assert!(room_b
|
||||
.update(cx_b, |room, cx| room.share_microphone(cx))
|
||||
|
||||
@@ -262,7 +262,7 @@ async fn test_dev_server_leave_room(
|
||||
cx1.executor().run_until_parked();
|
||||
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -308,7 +308,7 @@ async fn test_dev_server_delete(
|
||||
cx1.executor().run_until_parked();
|
||||
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
|
||||
cx1.update(|cx| {
|
||||
dev_server_projects::Store::global(cx).update(cx, |store, _| {
|
||||
@@ -418,12 +418,12 @@ async fn test_dev_server_refresh_access_token(
|
||||
|
||||
// Assert that the other client was disconnected
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
|
||||
// Assert that the owner of the dev server does not see the dev server as online anymore
|
||||
let (workspace, cx1) = client1.active_workspace(cx1);
|
||||
cx1.update(|cx| {
|
||||
assert!(workspace.read(cx).project().read(cx).is_disconnected());
|
||||
assert!(workspace.read(cx).project().read(cx).is_disconnected(cx));
|
||||
dev_server_projects::Store::global(cx).update(cx, |store, _| {
|
||||
assert_eq!(
|
||||
store.dev_servers().first().unwrap().status,
|
||||
|
||||
@@ -114,7 +114,7 @@ async fn test_host_disconnect(
|
||||
|
||||
project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
|
||||
|
||||
project_b.read_with(cx_b, |project, _| project.is_read_only());
|
||||
project_b.read_with(cx_b, |project, cx| project.is_read_only(cx));
|
||||
|
||||
assert!(worktree_a.read_with(cx_a, |tree, _| !tree.has_update_observer()));
|
||||
|
||||
@@ -379,75 +379,51 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu
|
||||
.next()
|
||||
.await
|
||||
.unwrap();
|
||||
cx_a.executor().finish_waiting();
|
||||
|
||||
// Open the buffer on the host.
|
||||
let buffer_a = project_a
|
||||
.update(cx_a, |p, cx| p.open_buffer((worktree_id, "main.rs"), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx_a.executor().run_until_parked();
|
||||
|
||||
buffer_a.read_with(cx_a, |buffer, _| {
|
||||
assert_eq!(buffer.text(), "fn main() { a. }")
|
||||
});
|
||||
|
||||
// Confirm a completion on the guest.
|
||||
editor_b.update(cx_b, |editor, cx| {
|
||||
assert!(editor.context_menu_visible());
|
||||
editor.confirm_completion(&ConfirmCompletion { item_ix: Some(0) }, cx);
|
||||
assert_eq!(editor.text(cx), "fn main() { a.first_method() }");
|
||||
});
|
||||
|
||||
// Return a resolved completion from the host's language server.
|
||||
// The resolved completion has an additional text edit.
|
||||
fake_language_server.handle_request::<lsp::request::ResolveCompletionItem, _, _>(
|
||||
|params, _| async move {
|
||||
Ok(match params.label.as_str() {
|
||||
"first_method(…)" => lsp::CompletionItem {
|
||||
label: "first_method(…)".into(),
|
||||
detail: Some("fn(&mut self, B) -> C".into()),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
new_text: "first_method($1)".to_string(),
|
||||
range: lsp::Range::new(
|
||||
lsp::Position::new(0, 14),
|
||||
lsp::Position::new(0, 14),
|
||||
),
|
||||
})),
|
||||
additional_text_edits: Some(vec![lsp::TextEdit {
|
||||
new_text: "use d::SomeTrait;\n".to_string(),
|
||||
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)),
|
||||
}]),
|
||||
insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
|
||||
..Default::default()
|
||||
},
|
||||
"second_method(…)" => lsp::CompletionItem {
|
||||
label: "second_method(…)".into(),
|
||||
detail: Some("fn(&mut self, C) -> D<E>".into()),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
new_text: "second_method()".to_string(),
|
||||
range: lsp::Range::new(
|
||||
lsp::Position::new(0, 14),
|
||||
lsp::Position::new(0, 14),
|
||||
),
|
||||
})),
|
||||
insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
|
||||
additional_text_edits: Some(vec![lsp::TextEdit {
|
||||
new_text: "use d::SomeTrait;\n".to_string(),
|
||||
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)),
|
||||
}]),
|
||||
..Default::default()
|
||||
},
|
||||
_ => panic!("unexpected completion label: {:?}", params.label),
|
||||
assert_eq!(params.label, "first_method(…)");
|
||||
Ok(lsp::CompletionItem {
|
||||
label: "first_method(…)".into(),
|
||||
detail: Some("fn(&mut self, B) -> C".into()),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
new_text: "first_method($1)".to_string(),
|
||||
range: lsp::Range::new(lsp::Position::new(0, 14), lsp::Position::new(0, 14)),
|
||||
})),
|
||||
additional_text_edits: Some(vec![lsp::TextEdit {
|
||||
new_text: "use d::SomeTrait;\n".to_string(),
|
||||
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)),
|
||||
}]),
|
||||
insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
|
||||
..Default::default()
|
||||
})
|
||||
},
|
||||
);
|
||||
cx_a.executor().finish_waiting();
|
||||
cx_a.executor().run_until_parked();
|
||||
|
||||
// Confirm a completion on the guest.
|
||||
editor_b
|
||||
.update(cx_b, |editor, cx| {
|
||||
assert!(editor.context_menu_visible());
|
||||
editor.confirm_completion(&ConfirmCompletion { item_ix: Some(0) }, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
cx_a.executor().run_until_parked();
|
||||
cx_b.executor().run_until_parked();
|
||||
|
||||
// The additional edit is applied.
|
||||
cx_a.executor().run_until_parked();
|
||||
|
||||
buffer_a.read_with(cx_a, |buffer, _| {
|
||||
assert_eq!(
|
||||
buffer.text(),
|
||||
@@ -540,15 +516,9 @@ async fn test_collaborating_with_completion(cx_a: &mut TestAppContext, cx_b: &mu
|
||||
cx_b.executor().run_until_parked();
|
||||
|
||||
// When accepting the completion, the snippet is insert.
|
||||
editor_b
|
||||
.update(cx_b, |editor, cx| {
|
||||
assert!(editor.context_menu_visible());
|
||||
editor.confirm_completion(&ConfirmCompletion { item_ix: Some(0) }, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
editor_b.update(cx_b, |editor, cx| {
|
||||
assert!(editor.context_menu_visible());
|
||||
editor.confirm_completion(&ConfirmCompletion { item_ix: Some(0) }, cx);
|
||||
assert_eq!(
|
||||
editor.text(cx),
|
||||
"use d::SomeTrait;\nfn main() { a.first_method(); a.third_method(, , ) }"
|
||||
|
||||
@@ -27,6 +27,7 @@ use language::{
|
||||
use live_kit_client::MacOSDisplay;
|
||||
use lsp::LanguageServerId;
|
||||
use parking_lot::Mutex;
|
||||
use project::lsp_store::FormatTarget;
|
||||
use project::{
|
||||
lsp_store::FormatTrigger, search::SearchQuery, search::SearchResult, DiagnosticSummary,
|
||||
HoverBlockKind, Project, ProjectPath,
|
||||
@@ -1389,7 +1390,7 @@ async fn test_unshare_project(
|
||||
.unwrap();
|
||||
executor.run_until_parked();
|
||||
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_disconnected()));
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_disconnected(cx)));
|
||||
|
||||
// Client C opens the project.
|
||||
let project_c = client_c.join_remote_project(project_id, cx_c).await;
|
||||
@@ -1402,7 +1403,7 @@ async fn test_unshare_project(
|
||||
|
||||
assert!(worktree_a.read_with(cx_a, |tree, _| !tree.has_update_observer()));
|
||||
|
||||
assert!(project_c.read_with(cx_c, |project, _| project.is_disconnected()));
|
||||
assert!(project_c.read_with(cx_c, |project, cx| project.is_disconnected(cx)));
|
||||
|
||||
// Client C can open the project again after client A re-shares.
|
||||
let project_id = active_call_a
|
||||
@@ -1427,8 +1428,8 @@ async fn test_unshare_project(
|
||||
|
||||
project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
|
||||
|
||||
project_c2.read_with(cx_c, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
project_c2.read_with(cx_c, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
assert!(project.collaborators().is_empty());
|
||||
});
|
||||
}
|
||||
@@ -1560,8 +1561,8 @@ async fn test_project_reconnect(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
project_b1.read_with(cx_b, |project, _| {
|
||||
assert!(!project.is_disconnected());
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
@@ -1661,7 +1662,7 @@ async fn test_project_reconnect(
|
||||
});
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected());
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert_eq!(
|
||||
project
|
||||
.worktree_for_id(worktree1_id, cx)
|
||||
@@ -1695,9 +1696,9 @@ async fn test_project_reconnect(
|
||||
);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, _| assert!(project.is_disconnected()));
|
||||
project_b2.read_with(cx_b, |project, cx| assert!(project.is_disconnected(cx)));
|
||||
|
||||
project_b3.read_with(cx_b, |project, _| assert!(!project.is_disconnected()));
|
||||
project_b3.read_with(cx_b, |project, cx| assert!(!project.is_disconnected(cx)));
|
||||
|
||||
buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WaZ"));
|
||||
|
||||
@@ -1754,7 +1755,7 @@ async fn test_project_reconnect(
|
||||
executor.run_until_parked();
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected());
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert_eq!(
|
||||
project
|
||||
.worktree_for_id(worktree1_id, cx)
|
||||
@@ -1788,7 +1789,7 @@ async fn test_project_reconnect(
|
||||
);
|
||||
});
|
||||
|
||||
project_b3.read_with(cx_b, |project, _| assert!(project.is_disconnected()));
|
||||
project_b3.read_with(cx_b, |project, cx| assert!(project.is_disconnected(cx)));
|
||||
|
||||
buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WXaYZ"));
|
||||
|
||||
@@ -3816,8 +3817,8 @@ async fn test_leaving_project(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
project_b2.read_with(cx_b, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
});
|
||||
|
||||
project_c.read_with(cx_c, |project, _| {
|
||||
@@ -3849,12 +3850,12 @@ async fn test_leaving_project(
|
||||
assert_eq!(project.collaborators().len(), 0);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
project_b2.read_with(cx_b, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
});
|
||||
|
||||
project_c.read_with(cx_c, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
project_c.read_with(cx_c, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -4417,6 +4418,7 @@ async fn test_formatting_buffer(
|
||||
HashSet::from_iter([buffer_b.clone()]),
|
||||
true,
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -4450,6 +4452,7 @@ async fn test_formatting_buffer(
|
||||
HashSet::from_iter([buffer_b.clone()]),
|
||||
true,
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -4555,6 +4558,7 @@ async fn test_prettier_formatting_buffer(
|
||||
HashSet::from_iter([buffer_b.clone()]),
|
||||
true,
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -4574,6 +4578,7 @@ async fn test_prettier_formatting_buffer(
|
||||
HashSet::from_iter([buffer_a.clone()]),
|
||||
true,
|
||||
FormatTrigger::Manual,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
||||
@@ -1168,7 +1168,7 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
Some((project, cx))
|
||||
});
|
||||
|
||||
if !guest_project.is_disconnected() {
|
||||
if !guest_project.is_disconnected(cx) {
|
||||
if let Some((host_project, host_cx)) = host_project {
|
||||
let host_worktree_snapshots =
|
||||
host_project.read_with(host_cx, |host_project, cx| {
|
||||
@@ -1254,8 +1254,8 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
|
||||
let buffers = client.buffers().clone();
|
||||
for (guest_project, guest_buffers) in &buffers {
|
||||
let project_id = if guest_project.read_with(client_cx, |project, _| {
|
||||
project.is_local() || project.is_disconnected()
|
||||
let project_id = if guest_project.read_with(client_cx, |project, cx| {
|
||||
project.is_local() || project.is_disconnected(cx)
|
||||
}) {
|
||||
continue;
|
||||
} else {
|
||||
|
||||
@@ -532,9 +532,9 @@ impl<T: RandomizedTest> TestPlan<T> {
|
||||
server.allow_connections();
|
||||
|
||||
for project in client.dev_server_projects().iter() {
|
||||
project.read_with(&client_cx, |project, _| {
|
||||
project.read_with(&client_cx, |project, cx| {
|
||||
assert!(
|
||||
project.is_disconnected(),
|
||||
project.is_disconnected(cx),
|
||||
"project {:?} should be read only",
|
||||
project.remote_id()
|
||||
)
|
||||
|
||||
@@ -2,10 +2,12 @@ use crate::tests::TestServer;
|
||||
use call::ActiveCall;
|
||||
use fs::{FakeFs, Fs as _};
|
||||
use gpui::{Context as _, TestAppContext};
|
||||
use language::language_settings::all_language_settings;
|
||||
use http_client::BlockedHttpClient;
|
||||
use language::{language_settings::all_language_settings, LanguageRegistry};
|
||||
use node_runtime::NodeRuntime;
|
||||
use project::ProjectPath;
|
||||
use remote::SshRemoteClient;
|
||||
use remote_server::HeadlessProject;
|
||||
use remote_server::{HeadlessAppState, HeadlessProject};
|
||||
use serde_json::json;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
@@ -48,8 +50,22 @@ async fn test_sharing_an_ssh_remote_project(
|
||||
|
||||
// User A connects to the remote project via SSH.
|
||||
server_cx.update(HeadlessProject::init);
|
||||
let _headless_project =
|
||||
server_cx.new_model(|cx| HeadlessProject::new(server_ssh, remote_fs.clone(), cx));
|
||||
let remote_http_client = Arc::new(BlockedHttpClient);
|
||||
let node = NodeRuntime::unavailable();
|
||||
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
|
||||
let _headless_project = server_cx.new_model(|cx| {
|
||||
client::init_settings(cx);
|
||||
HeadlessProject::new(
|
||||
HeadlessAppState {
|
||||
session: server_ssh,
|
||||
fs: remote_fs.clone(),
|
||||
http_client: remote_http_client,
|
||||
node_runtime: node,
|
||||
languages,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let (project_a, worktree_id) = client_a
|
||||
.build_ssh_project("/code/project1", client_ssh, cx_a)
|
||||
|
||||
@@ -635,9 +635,11 @@ impl TestServer {
|
||||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
llm_db: None,
|
||||
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
stripe_client: None,
|
||||
stripe_billing: None,
|
||||
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
|
||||
executor,
|
||||
clickhouse_client: None,
|
||||
@@ -677,7 +679,6 @@ impl TestServer {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_llm_usage_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
},
|
||||
|
||||
@@ -111,7 +111,7 @@ impl MessageEditor {
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.set_show_wrap_guides(false, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_completion_provider(Box::new(MessageEditorCompletionProvider(this)));
|
||||
editor.set_completion_provider(Some(Box::new(MessageEditorCompletionProvider(this))));
|
||||
editor.set_auto_replace_emoji_shortcode(
|
||||
MessageEditorSettings::get_global(cx)
|
||||
.auto_replace_emoji_shortcode
|
||||
|
||||
@@ -363,12 +363,10 @@ mod tests {
|
||||
|
||||
// Confirming a completion inserts it and hides the context menu, without showing
|
||||
// the copilot suggestion afterwards.
|
||||
editor.confirm_completion(&Default::default(), cx).unwrap()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update_editor(|editor, cx| {
|
||||
editor
|
||||
.confirm_completion(&Default::default(), cx)
|
||||
.unwrap()
|
||||
.detach();
|
||||
assert!(!editor.context_menu_visible());
|
||||
assert!(!editor.has_active_inline_completion(cx));
|
||||
assert_eq!(editor.text(cx), "one.completion_a\ntwo\nthree\n");
|
||||
|
||||
@@ -237,6 +237,7 @@ gpui::actions!(
|
||||
ToggleFold,
|
||||
ToggleFoldRecursive,
|
||||
Format,
|
||||
FormatSelections,
|
||||
GoToDeclaration,
|
||||
GoToDeclarationSplit,
|
||||
GoToDefinition,
|
||||
@@ -294,6 +295,7 @@ gpui::actions!(
|
||||
RevealInFileManager,
|
||||
ReverseLines,
|
||||
RevertFile,
|
||||
ReloadFile,
|
||||
RevertSelectedHunks,
|
||||
Rewrap,
|
||||
ScrollCursorBottom,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{ops::ControlFlow, time::Duration};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::{channel::oneshot, FutureExt};
|
||||
use gpui::{Task, ViewContext};
|
||||
@@ -7,7 +7,7 @@ use crate::Editor;
|
||||
|
||||
pub struct DebouncedDelay {
|
||||
task: Option<Task<()>>,
|
||||
cancel_channel: Option<oneshot::Sender<ControlFlow<()>>>,
|
||||
cancel_channel: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
impl DebouncedDelay {
|
||||
@@ -23,22 +23,17 @@ impl DebouncedDelay {
|
||||
F: 'static + Send + FnOnce(&mut Editor, &mut ViewContext<Editor>) -> Task<()>,
|
||||
{
|
||||
if let Some(channel) = self.cancel_channel.take() {
|
||||
channel.send(ControlFlow::Break(())).ok();
|
||||
_ = channel.send(());
|
||||
}
|
||||
|
||||
let (sender, mut receiver) = oneshot::channel::<ControlFlow<()>>();
|
||||
let (sender, mut receiver) = oneshot::channel::<()>();
|
||||
self.cancel_channel = Some(sender);
|
||||
|
||||
drop(self.task.take());
|
||||
self.task = Some(cx.spawn(move |model, mut cx| async move {
|
||||
let mut timer = cx.background_executor().timer(delay).fuse();
|
||||
futures::select_biased! {
|
||||
interrupt = receiver => {
|
||||
match interrupt {
|
||||
Ok(ControlFlow::Break(())) | Err(_) => return,
|
||||
Ok(ControlFlow::Continue(())) => {},
|
||||
}
|
||||
}
|
||||
_ = receiver => return,
|
||||
_ = timer => {}
|
||||
}
|
||||
|
||||
@@ -47,11 +42,4 @@ impl DebouncedDelay {
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
pub fn start_now(&mut self) -> Option<Task<()>> {
|
||||
if let Some(channel) = self.cancel_channel.take() {
|
||||
channel.send(ControlFlow::Continue(())).ok();
|
||||
}
|
||||
self.task.take()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,7 +96,9 @@ use language::{
|
||||
CursorShape, Diagnostic, Documentation, IndentKind, IndentSize, Language, OffsetRangeExt,
|
||||
Point, Selection, SelectionGoal, TransactionId,
|
||||
};
|
||||
use language::{point_to_lsp, BufferRow, CharClassifier, Runnable, RunnableRange};
|
||||
use language::{
|
||||
point_to_lsp, BufferRow, CharClassifier, LanguageServerName, Runnable, RunnableRange,
|
||||
};
|
||||
use linked_editing_ranges::refresh_linked_ranges;
|
||||
pub use proposed_changes_editor::{
|
||||
ProposedChangesBuffer, ProposedChangesEditor, ProposedChangesEditorToolbar,
|
||||
@@ -121,10 +123,11 @@ use multi_buffer::{
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use project::project_settings::{GitGutterSetting, ProjectSettings};
|
||||
use project::{
|
||||
lsp_store::FormatTrigger, CodeAction, Completion, CompletionIntent, Item, Location, Project,
|
||||
ProjectPath, ProjectTransaction, TaskSourceKind,
|
||||
lsp_store::{FormatTarget, FormatTrigger},
|
||||
project_settings::{GitGutterSetting, ProjectSettings},
|
||||
CodeAction, Completion, CompletionIntent, DocumentHighlight, InlayHint, Item, Location,
|
||||
LocationLink, Project, ProjectPath, ProjectTransaction, TaskSourceKind,
|
||||
};
|
||||
use rand::prelude::*;
|
||||
use rpc::{proto::*, ErrorExt};
|
||||
@@ -160,11 +163,11 @@ use ui::{
|
||||
};
|
||||
use util::{defer, maybe, post_inc, RangeExt, ResultExt, TryFutureExt};
|
||||
use workspace::item::{ItemHandle, PreviewTabsSettings};
|
||||
use workspace::notifications::{DetachAndPromptErr, NotificationId};
|
||||
use workspace::notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt};
|
||||
use workspace::{
|
||||
searchable::SearchEvent, ItemNavHistory, SplitDirection, ViewId, Workspace, WorkspaceId,
|
||||
};
|
||||
use workspace::{OpenInTerminal, OpenTerminal, TabBarSettings, Toast};
|
||||
use workspace::{Item as WorkspaceItem, OpenInTerminal, OpenTerminal, TabBarSettings, Toast};
|
||||
|
||||
use crate::hover_links::find_url;
|
||||
use crate::signature_help::{SignatureHelpHiddenBy, SignatureHelpState};
|
||||
@@ -546,6 +549,7 @@ pub struct Editor {
|
||||
active_diagnostics: Option<ActiveDiagnosticGroup>,
|
||||
soft_wrap_mode_override: Option<language_settings::SoftWrap>,
|
||||
project: Option<Model<Project>>,
|
||||
semantics_provider: Option<Rc<dyn SemanticsProvider>>,
|
||||
completion_provider: Option<Box<dyn CompletionProvider>>,
|
||||
collaboration_hub: Option<Box<dyn CollaborationHub>>,
|
||||
blink_manager: Model<BlinkManager>,
|
||||
@@ -884,12 +888,12 @@ enum ContextMenu {
|
||||
impl ContextMenu {
|
||||
fn select_first(
|
||||
&mut self,
|
||||
project: Option<&Model<Project>>,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_first(project, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_first(provider, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_first(cx),
|
||||
}
|
||||
true
|
||||
@@ -900,12 +904,12 @@ impl ContextMenu {
|
||||
|
||||
fn select_prev(
|
||||
&mut self,
|
||||
project: Option<&Model<Project>>,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_prev(project, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_prev(provider, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_prev(cx),
|
||||
}
|
||||
true
|
||||
@@ -916,12 +920,12 @@ impl ContextMenu {
|
||||
|
||||
fn select_next(
|
||||
&mut self,
|
||||
project: Option<&Model<Project>>,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_next(project, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_next(provider, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_next(cx),
|
||||
}
|
||||
true
|
||||
@@ -932,12 +936,12 @@ impl ContextMenu {
|
||||
|
||||
fn select_last(
|
||||
&mut self,
|
||||
project: Option<&Model<Project>>,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_last(project, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_last(provider, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_last(cx),
|
||||
}
|
||||
true
|
||||
@@ -991,39 +995,55 @@ struct CompletionsMenu {
|
||||
}
|
||||
|
||||
impl CompletionsMenu {
|
||||
fn select_first(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
fn select_first(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
self.selected_item = 0;
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn select_prev(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
fn select_prev(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
if self.selected_item > 0 {
|
||||
self.selected_item -= 1;
|
||||
} else {
|
||||
self.selected_item = self.matches.len() - 1;
|
||||
}
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn select_next(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
fn select_next(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
if self.selected_item + 1 < self.matches.len() {
|
||||
self.selected_item += 1;
|
||||
} else {
|
||||
self.selected_item = 0;
|
||||
}
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn select_last(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
fn select_last(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
self.selected_item = self.matches.len() - 1;
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -1059,7 +1079,7 @@ impl CompletionsMenu {
|
||||
|
||||
fn attempt_resolve_selected_completion_documentation(
|
||||
&mut self,
|
||||
project: Option<&Model<Project>>,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
let settings = EditorSettings::get_global(cx);
|
||||
@@ -1068,18 +1088,16 @@ impl CompletionsMenu {
|
||||
}
|
||||
|
||||
let completion_index = self.matches[self.selected_item].candidate_id;
|
||||
let Some(project) = project else {
|
||||
let Some(provider) = provider else {
|
||||
return;
|
||||
};
|
||||
|
||||
let resolve_task = project.update(cx, |project, cx| {
|
||||
project.resolve_completions(
|
||||
self.buffer.clone(),
|
||||
vec![completion_index],
|
||||
self.completions.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let resolve_task = provider.resolve_completions(
|
||||
self.buffer.clone(),
|
||||
vec![completion_index],
|
||||
self.completions.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
let delay_ms =
|
||||
EditorSettings::get_global(cx).completion_documentation_secondary_query_debounce;
|
||||
@@ -1671,7 +1689,7 @@ pub(crate) struct NavigationData {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum GotoDefinitionKind {
|
||||
pub enum GotoDefinitionKind {
|
||||
Symbol,
|
||||
Declaration,
|
||||
Type,
|
||||
@@ -1879,10 +1897,17 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
}));
|
||||
let task_inventory = project.read(cx).task_inventory().clone();
|
||||
project_subscriptions.push(cx.observe(&task_inventory, |editor, _, cx| {
|
||||
editor.tasks_update_task = Some(editor.refresh_runnables(cx));
|
||||
}));
|
||||
if let Some(task_inventory) = project
|
||||
.read(cx)
|
||||
.task_store()
|
||||
.read(cx)
|
||||
.task_inventory()
|
||||
.cloned()
|
||||
{
|
||||
project_subscriptions.push(cx.observe(&task_inventory, |editor, _, cx| {
|
||||
editor.tasks_update_task = Some(editor.refresh_runnables(cx));
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1930,6 +1955,7 @@ impl Editor {
|
||||
active_diagnostics: None,
|
||||
soft_wrap_mode_override,
|
||||
completion_provider: project.clone().map(|project| Box::new(project) as _),
|
||||
semantics_provider: project.clone().map(|project| Rc::new(project) as _),
|
||||
collaboration_hub: project.clone().map(|project| Box::new(project) as _),
|
||||
project,
|
||||
blink_manager: blink_manager.clone(),
|
||||
@@ -2298,8 +2324,16 @@ impl Editor {
|
||||
self.custom_context_menu = Some(Box::new(f))
|
||||
}
|
||||
|
||||
pub fn set_completion_provider(&mut self, provider: Box<dyn CompletionProvider>) {
|
||||
self.completion_provider = Some(provider);
|
||||
pub fn set_completion_provider(&mut self, provider: Option<Box<dyn CompletionProvider>>) {
|
||||
self.completion_provider = provider;
|
||||
}
|
||||
|
||||
pub fn semantics_provider(&self) -> Option<Rc<dyn SemanticsProvider>> {
|
||||
self.semantics_provider.clone()
|
||||
}
|
||||
|
||||
pub fn set_semantics_provider(&mut self, provider: Option<Rc<dyn SemanticsProvider>>) {
|
||||
self.semantics_provider = provider;
|
||||
}
|
||||
|
||||
pub fn set_inline_completion_provider<T>(
|
||||
@@ -4034,7 +4068,7 @@ impl Editor {
|
||||
}
|
||||
|
||||
fn refresh_inlay_hints(&mut self, reason: InlayHintRefreshReason, cx: &mut ViewContext<Self>) {
|
||||
if self.project.is_none() || self.mode != EditorMode::Full {
|
||||
if self.semantics_provider.is_none() || self.mode != EditorMode::Full {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4427,49 +4461,16 @@ impl Editor {
|
||||
&mut self,
|
||||
item_ix: Option<usize>,
|
||||
intent: CompletionIntent,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<Task<anyhow::Result<()>>> {
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> Option<Task<std::result::Result<(), anyhow::Error>>> {
|
||||
use language::ToOffset as _;
|
||||
|
||||
let completions_menu = if let ContextMenu::Completions(menu) = self.hide_context_menu(cx)? {
|
||||
menu
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let mut resolve_task_store = completions_menu
|
||||
.selected_completion_documentation_resolve_debounce
|
||||
.lock();
|
||||
let selected_completion_resolve = resolve_task_store.start_now();
|
||||
let menu_pre_resolve = self
|
||||
.completion_documentation_pre_resolve_debounce
|
||||
.start_now();
|
||||
drop(resolve_task_store);
|
||||
|
||||
Some(cx.spawn(|editor, mut cx| async move {
|
||||
match (selected_completion_resolve, menu_pre_resolve) {
|
||||
(None, None) => {}
|
||||
(Some(resolve), None) | (None, Some(resolve)) => resolve.await,
|
||||
(Some(resolve_1), Some(resolve_2)) => {
|
||||
futures::join!(resolve_1, resolve_2);
|
||||
}
|
||||
}
|
||||
if let Some(apply_edits_task) = editor.update(&mut cx, |editor, cx| {
|
||||
editor.apply_resolved_completion(completions_menu, item_ix, intent, cx)
|
||||
})? {
|
||||
apply_edits_task.await?;
|
||||
}
|
||||
Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
fn apply_resolved_completion(
|
||||
&mut self,
|
||||
completions_menu: CompletionsMenu,
|
||||
item_ix: Option<usize>,
|
||||
intent: CompletionIntent,
|
||||
cx: &mut ViewContext<'_, Editor>,
|
||||
) -> Option<Task<anyhow::Result<Option<language::Transaction>>>> {
|
||||
use language::ToOffset as _;
|
||||
|
||||
let mat = completions_menu
|
||||
.matches
|
||||
.get(item_ix.unwrap_or(completions_menu.selected_item))?;
|
||||
@@ -4628,7 +4629,11 @@ impl Editor {
|
||||
// so we should automatically call signature_help
|
||||
self.show_signature_help(&ShowSignatureHelp, cx);
|
||||
}
|
||||
Some(apply_edits)
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
apply_edits.await?;
|
||||
Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn toggle_code_actions(&mut self, action: &ToggleCodeActions, cx: &mut ViewContext<Self>) {
|
||||
@@ -4717,11 +4722,13 @@ impl Editor {
|
||||
);
|
||||
}
|
||||
project.update(cx, |project, cx| {
|
||||
project.task_context_for_location(
|
||||
captured_task_variables,
|
||||
location,
|
||||
cx,
|
||||
)
|
||||
project.task_store().update(cx, |task_store, cx| {
|
||||
task_store.task_context_for_location(
|
||||
captured_task_variables,
|
||||
location,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
@@ -4933,6 +4940,11 @@ impl Editor {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn clear_code_action_providers(&mut self) {
|
||||
self.code_action_providers.clear();
|
||||
self.available_code_actions.take();
|
||||
}
|
||||
|
||||
pub fn push_code_action_provider(
|
||||
&mut self,
|
||||
provider: Arc<dyn CodeActionProvider>,
|
||||
@@ -5020,7 +5032,7 @@ impl Editor {
|
||||
return None;
|
||||
}
|
||||
|
||||
let project = self.project.clone()?;
|
||||
let provider = self.semantics_provider.clone()?;
|
||||
let buffer = self.buffer.read(cx);
|
||||
let newest_selection = self.selections.newest_anchor().clone();
|
||||
let cursor_position = newest_selection.head();
|
||||
@@ -5036,11 +5048,12 @@ impl Editor {
|
||||
.timer(DOCUMENT_HIGHLIGHTS_DEBOUNCE_TIMEOUT)
|
||||
.await;
|
||||
|
||||
let highlights = if let Some(highlights) = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.document_highlights(&cursor_buffer, cursor_buffer_position, cx)
|
||||
let highlights = if let Some(highlights) = cx
|
||||
.update(|cx| {
|
||||
provider.document_highlights(&cursor_buffer, cursor_buffer_position, cx)
|
||||
})
|
||||
.log_err()
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
highlights.await.log_err()
|
||||
} else {
|
||||
@@ -6230,6 +6243,13 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_file(&mut self, _: &ReloadFile, cx: &mut ViewContext<Self>) {
|
||||
let Some(project) = self.project.clone() else {
|
||||
return;
|
||||
};
|
||||
self.reload(project, cx).detach_and_notify_err(cx);
|
||||
}
|
||||
|
||||
pub fn revert_selected_hunks(&mut self, _: &RevertSelectedHunks, cx: &mut ViewContext<Self>) {
|
||||
let revert_changes = self.gather_revert_changes(&self.selections.disjoint_anchors(), cx);
|
||||
if !revert_changes.is_empty() {
|
||||
@@ -7462,7 +7482,7 @@ impl Editor {
|
||||
.context_menu
|
||||
.write()
|
||||
.as_mut()
|
||||
.map(|menu| menu.select_first(self.project.as_ref(), cx))
|
||||
.map(|menu| menu.select_first(self.completion_provider.as_deref(), cx))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
@@ -7571,7 +7591,7 @@ impl Editor {
|
||||
.context_menu
|
||||
.write()
|
||||
.as_mut()
|
||||
.map(|menu| menu.select_last(self.project.as_ref(), cx))
|
||||
.map(|menu| menu.select_last(self.completion_provider.as_deref(), cx))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
@@ -7623,25 +7643,25 @@ impl Editor {
|
||||
|
||||
pub fn context_menu_first(&mut self, _: &ContextMenuFirst, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_first(self.project.as_ref(), cx);
|
||||
context_menu.select_first(self.completion_provider.as_deref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_prev(&mut self, _: &ContextMenuPrev, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_prev(self.project.as_ref(), cx);
|
||||
context_menu.select_prev(self.completion_provider.as_deref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_next(&mut self, _: &ContextMenuNext, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_next(self.project.as_ref(), cx);
|
||||
context_menu.select_next(self.completion_provider.as_deref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_last(&mut self, _: &ContextMenuLast, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_last(self.project.as_ref(), cx);
|
||||
context_menu.select_last(self.completion_provider.as_deref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9134,23 +9154,29 @@ impl Editor {
|
||||
.map(|file| (file.worktree_id(cx), file.clone()))
|
||||
.unzip();
|
||||
|
||||
(project.task_inventory().clone(), worktree_id, file)
|
||||
(
|
||||
project.task_store().read(cx).task_inventory().cloned(),
|
||||
worktree_id,
|
||||
file,
|
||||
)
|
||||
});
|
||||
|
||||
let inventory = inventory.read(cx);
|
||||
let tags = mem::take(&mut runnable.tags);
|
||||
let mut tags: Vec<_> = tags
|
||||
.into_iter()
|
||||
.flat_map(|tag| {
|
||||
let tag = tag.0.clone();
|
||||
inventory
|
||||
.list_tasks(
|
||||
file.clone(),
|
||||
Some(runnable.language.clone()),
|
||||
worktree_id,
|
||||
cx,
|
||||
)
|
||||
.as_ref()
|
||||
.into_iter()
|
||||
.flat_map(|inventory| {
|
||||
inventory.read(cx).list_tasks(
|
||||
file.clone(),
|
||||
Some(runnable.language.clone()),
|
||||
worktree_id,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.filter(move |(_, template)| {
|
||||
template.tags.iter().any(|source_tag| source_tag == &tag)
|
||||
})
|
||||
@@ -9608,7 +9634,7 @@ impl Editor {
|
||||
split: bool,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Task<Result<Navigated>> {
|
||||
let Some(workspace) = self.workspace() else {
|
||||
let Some(provider) = self.semantics_provider.clone() else {
|
||||
return Task::ready(Ok(Navigated::No));
|
||||
};
|
||||
let buffer = self.buffer.read(cx);
|
||||
@@ -9619,13 +9645,9 @@ impl Editor {
|
||||
return Task::ready(Ok(Navigated::No));
|
||||
};
|
||||
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let definitions = project.update(cx, |project, cx| match kind {
|
||||
GotoDefinitionKind::Symbol => project.definition(&buffer, head, cx),
|
||||
GotoDefinitionKind::Declaration => project.declaration(&buffer, head, cx),
|
||||
GotoDefinitionKind::Type => project.type_definition(&buffer, head, cx),
|
||||
GotoDefinitionKind::Implementation => project.implementation(&buffer, head, cx),
|
||||
});
|
||||
let Some(definitions) = provider.definitions(&buffer, head, kind, cx) else {
|
||||
return Task::ready(Ok(Navigated::No));
|
||||
};
|
||||
|
||||
cx.spawn(|editor, mut cx| async move {
|
||||
let definitions = definitions.await?;
|
||||
@@ -9682,9 +9704,7 @@ impl Editor {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(project) = self.project.clone() else {
|
||||
return;
|
||||
};
|
||||
let project = self.project.clone();
|
||||
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let result = find_file(&buffer, project, buffer_position, &mut cx).await;
|
||||
@@ -9875,21 +9895,19 @@ impl Editor {
|
||||
&self,
|
||||
lsp_location: lsp::Location,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Task<anyhow::Result<Option<Location>>> {
|
||||
let Some(project) = self.project.clone() else {
|
||||
return Task::Ready(Some(Ok(None)));
|
||||
};
|
||||
|
||||
cx.spawn(move |editor, mut cx| async move {
|
||||
let location_task = editor.update(&mut cx, |editor, cx| {
|
||||
let location_task = editor.update(&mut cx, |_, cx| {
|
||||
project.update(cx, |project, cx| {
|
||||
let language_server_name =
|
||||
editor.buffer.read(cx).as_singleton().and_then(|buffer| {
|
||||
project
|
||||
.language_server_for_buffer(buffer.read(cx), server_id, cx)
|
||||
.map(|(lsp_adapter, _)| lsp_adapter.name.clone())
|
||||
});
|
||||
let language_server_name = project
|
||||
.language_server_statuses(cx)
|
||||
.find(|(id, _)| server_id == *id)
|
||||
.map(|(_, status)| LanguageServerName::from(status.name.as_str()));
|
||||
language_server_name.map(|language_server_name| {
|
||||
project.open_local_buffer_via_lsp(
|
||||
lsp_location.uri.clone(),
|
||||
@@ -10086,7 +10104,7 @@ impl Editor {
|
||||
pub fn rename(&mut self, _: &Rename, cx: &mut ViewContext<Self>) -> Option<Task<Result<()>>> {
|
||||
use language::ToOffset as _;
|
||||
|
||||
let project = self.project.clone()?;
|
||||
let provider = self.semantics_provider.clone()?;
|
||||
let selection = self.selections.newest_anchor().clone();
|
||||
let (cursor_buffer, cursor_buffer_position) = self
|
||||
.buffer
|
||||
@@ -10103,9 +10121,9 @@ impl Editor {
|
||||
let snapshot = cursor_buffer.read(cx).snapshot();
|
||||
let cursor_buffer_offset = cursor_buffer_position.to_offset(&snapshot);
|
||||
let cursor_buffer_offset_end = cursor_buffer_position_end.to_offset(&snapshot);
|
||||
let prepare_rename = project.update(cx, |project, cx| {
|
||||
project.prepare_rename(cursor_buffer.clone(), cursor_buffer_offset, cx)
|
||||
});
|
||||
let prepare_rename = provider
|
||||
.range_for_rename(&cursor_buffer, cursor_buffer_position, cx)
|
||||
.unwrap_or_else(|| Task::ready(Ok(None)));
|
||||
drop(snapshot);
|
||||
|
||||
Some(cx.spawn(|this, mut cx| async move {
|
||||
@@ -10276,32 +10294,28 @@ impl Editor {
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
let rename = self.take_rename(false, cx)?;
|
||||
let workspace = self.workspace()?;
|
||||
let (start_buffer, start) = self
|
||||
let workspace = self.workspace()?.downgrade();
|
||||
let (buffer, start) = self
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(rename.range.start, cx)?;
|
||||
let (end_buffer, end) = self
|
||||
let (end_buffer, _) = self
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(rename.range.end, cx)?;
|
||||
if start_buffer != end_buffer {
|
||||
if buffer != end_buffer {
|
||||
return None;
|
||||
}
|
||||
|
||||
let buffer = start_buffer;
|
||||
let range = start..end;
|
||||
let old_name = rename.old_name;
|
||||
let new_name = rename.editor.read(cx).text(cx);
|
||||
|
||||
let rename = workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.clone()
|
||||
.update(cx, |project, cx| {
|
||||
project.perform_rename(buffer.clone(), range.start, new_name.clone(), true, cx)
|
||||
});
|
||||
let workspace = workspace.downgrade();
|
||||
let rename = self.semantics_provider.as_ref()?.perform_rename(
|
||||
&buffer,
|
||||
start,
|
||||
new_name.clone(),
|
||||
cx,
|
||||
)?;
|
||||
|
||||
Some(cx.spawn(|editor, mut cx| async move {
|
||||
let project_transaction = rename.await?;
|
||||
@@ -10372,13 +10386,39 @@ impl Editor {
|
||||
None => return None,
|
||||
};
|
||||
|
||||
Some(self.perform_format(project, FormatTrigger::Manual, cx))
|
||||
Some(self.perform_format(project, FormatTrigger::Manual, FormatTarget::Buffer, cx))
|
||||
}
|
||||
|
||||
fn format_selections(
|
||||
&mut self,
|
||||
_: &FormatSelections,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
let project = match &self.project {
|
||||
Some(project) => project.clone(),
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let selections = self
|
||||
.selections
|
||||
.all_adjusted(cx)
|
||||
.into_iter()
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect_vec();
|
||||
|
||||
Some(self.perform_format(
|
||||
project,
|
||||
FormatTrigger::Manual,
|
||||
FormatTarget::Ranges(selections),
|
||||
cx,
|
||||
))
|
||||
}
|
||||
|
||||
fn perform_format(
|
||||
&mut self,
|
||||
project: Model<Project>,
|
||||
trigger: FormatTrigger,
|
||||
target: FormatTarget,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let buffer = self.buffer().clone();
|
||||
@@ -10388,7 +10428,9 @@ impl Editor {
|
||||
}
|
||||
|
||||
let mut timeout = cx.background_executor().timer(FORMAT_TIMEOUT).fuse();
|
||||
let format = project.update(cx, |project, cx| project.format(buffers, true, trigger, cx));
|
||||
let format = project.update(cx, |project, cx| {
|
||||
project.format(buffers, true, trigger, target, cx)
|
||||
});
|
||||
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let transaction = futures::select_biased! {
|
||||
@@ -12356,14 +12398,22 @@ impl Editor {
|
||||
|
||||
let mut new_selections_by_buffer = HashMap::default();
|
||||
for selection in self.selections.all::<usize>(cx) {
|
||||
for (buffer, mut range, _) in
|
||||
buffer.range_to_buffer_ranges(selection.start..selection.end, cx)
|
||||
for (mut buffer_handle, mut range, _) in
|
||||
buffer.range_to_buffer_ranges(selection.range(), cx)
|
||||
{
|
||||
// When editing branch buffers, jump to the corresponding location
|
||||
// in their base buffer.
|
||||
let buffer = buffer_handle.read(cx);
|
||||
if let Some(base_buffer) = buffer.diff_base_buffer() {
|
||||
range = buffer.range_to_version(range, &base_buffer.read(cx).version());
|
||||
buffer_handle = base_buffer;
|
||||
}
|
||||
|
||||
if selection.reversed {
|
||||
mem::swap(&mut range.start, &mut range.end);
|
||||
}
|
||||
new_selections_by_buffer
|
||||
.entry(buffer)
|
||||
.entry(buffer_handle)
|
||||
.or_insert(Vec::new())
|
||||
.push(range)
|
||||
}
|
||||
@@ -12648,24 +12698,13 @@ impl Editor {
|
||||
}
|
||||
|
||||
pub fn supports_inlay_hints(&self, cx: &AppContext) -> bool {
|
||||
let Some(project) = self.project.as_ref() else {
|
||||
let Some(provider) = self.semantics_provider.as_ref() else {
|
||||
return false;
|
||||
};
|
||||
let project = project.read(cx);
|
||||
|
||||
let mut supports = false;
|
||||
self.buffer().read(cx).for_each_buffer(|buffer| {
|
||||
if !supports {
|
||||
supports = project
|
||||
.language_servers_for_buffer(buffer.read(cx), cx)
|
||||
.any(
|
||||
|(_, server)| match server.capabilities().inlay_hint_provider {
|
||||
Some(lsp::OneOf::Left(enabled)) => enabled,
|
||||
Some(lsp::OneOf::Right(_)) => true,
|
||||
None => false,
|
||||
},
|
||||
)
|
||||
}
|
||||
supports |= provider.supports_inlay_hints(buffer, cx);
|
||||
});
|
||||
supports
|
||||
}
|
||||
@@ -12931,6 +12970,62 @@ impl CollaborationHub for Model<Project> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SemanticsProvider {
|
||||
fn hover(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Vec<project::Hover>>>;
|
||||
|
||||
fn inlay_hints(
|
||||
&self,
|
||||
buffer_handle: Model<Buffer>,
|
||||
range: Range<text::Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<Vec<InlayHint>>>>;
|
||||
|
||||
fn resolve_inlay_hint(
|
||||
&self,
|
||||
hint: InlayHint,
|
||||
buffer_handle: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<InlayHint>>>;
|
||||
|
||||
fn supports_inlay_hints(&self, buffer: &Model<Buffer>, cx: &AppContext) -> bool;
|
||||
|
||||
fn document_highlights(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<DocumentHighlight>>>>;
|
||||
|
||||
fn definitions(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
kind: GotoDefinitionKind,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<LocationLink>>>>;
|
||||
|
||||
fn range_for_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Option<Range<text::Anchor>>>>>;
|
||||
|
||||
fn perform_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
new_name: String,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<ProjectTransaction>>>;
|
||||
}
|
||||
|
||||
pub trait CompletionProvider {
|
||||
fn completions(
|
||||
&self,
|
||||
@@ -13027,18 +13122,11 @@ fn snippet_completions(
|
||||
return vec![];
|
||||
}
|
||||
let snapshot = buffer.read(cx).text_snapshot();
|
||||
let chunks = snapshot.reversed_chunks_in_range(text::Anchor::MIN..buffer_position);
|
||||
|
||||
let mut lines = chunks.lines();
|
||||
let Some(line_at) = lines.next().filter(|line| !line.is_empty()) else {
|
||||
return vec![];
|
||||
};
|
||||
let chars = snapshot.reversed_chars_for_range(text::Anchor::MIN..buffer_position);
|
||||
|
||||
let scope = language.map(|language| language.default_scope());
|
||||
let classifier = CharClassifier::new(scope).for_completion(true);
|
||||
let mut last_word = line_at
|
||||
.chars()
|
||||
.rev()
|
||||
let mut last_word = chars
|
||||
.take_while(|c| classifier.is_word(*c))
|
||||
.collect::<String>();
|
||||
last_word = last_word.chars().rev().collect();
|
||||
@@ -13182,6 +13270,102 @@ impl CompletionProvider for Model<Project> {
|
||||
}
|
||||
}
|
||||
|
||||
impl SemanticsProvider for Model<Project> {
|
||||
fn hover(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Vec<project::Hover>>> {
|
||||
Some(self.update(cx, |project, cx| project.hover(buffer, position, cx)))
|
||||
}
|
||||
|
||||
fn document_highlights(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<DocumentHighlight>>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.document_highlights(buffer, position, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn definitions(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
kind: GotoDefinitionKind,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<LocationLink>>>> {
|
||||
Some(self.update(cx, |project, cx| match kind {
|
||||
GotoDefinitionKind::Symbol => project.definition(&buffer, position, cx),
|
||||
GotoDefinitionKind::Declaration => project.declaration(&buffer, position, cx),
|
||||
GotoDefinitionKind::Type => project.type_definition(&buffer, position, cx),
|
||||
GotoDefinitionKind::Implementation => project.implementation(&buffer, position, cx),
|
||||
}))
|
||||
}
|
||||
|
||||
fn supports_inlay_hints(&self, buffer: &Model<Buffer>, cx: &AppContext) -> bool {
|
||||
// TODO: make this work for remote projects
|
||||
self.read(cx)
|
||||
.language_servers_for_buffer(buffer.read(cx), cx)
|
||||
.any(
|
||||
|(_, server)| match server.capabilities().inlay_hint_provider {
|
||||
Some(lsp::OneOf::Left(enabled)) => enabled,
|
||||
Some(lsp::OneOf::Right(_)) => true,
|
||||
None => false,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn inlay_hints(
|
||||
&self,
|
||||
buffer_handle: Model<Buffer>,
|
||||
range: Range<text::Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<Vec<InlayHint>>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.inlay_hints(buffer_handle, range, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn resolve_inlay_hint(
|
||||
&self,
|
||||
hint: InlayHint,
|
||||
buffer_handle: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<InlayHint>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.resolve_inlay_hint(hint, buffer_handle, server_id, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn range_for_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Option<Range<text::Anchor>>>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.prepare_rename(buffer.clone(), position, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn perform_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
new_name: String,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<ProjectTransaction>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.perform_rename(buffer.clone(), position, new_name, cx)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn inlay_hint_settings(
|
||||
location: Anchor,
|
||||
snapshot: &MultiBufferSnapshot,
|
||||
@@ -13475,6 +13659,7 @@ pub enum EditorEvent {
|
||||
TransactionBegun {
|
||||
transaction_id: clock::Lamport,
|
||||
},
|
||||
Reloaded,
|
||||
CursorShapeChanged,
|
||||
}
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ pub struct EditorSettingsContent {
|
||||
/// Default: true
|
||||
pub cursor_blink: Option<bool>,
|
||||
/// Cursor shape for the default editor.
|
||||
/// Can be "bar", "block", "underscore", or "hollow".
|
||||
/// Can be "bar", "block", "underline", or "hollow".
|
||||
///
|
||||
/// Default: None
|
||||
pub cursor_shape: Option<CursorShape>,
|
||||
|
||||
@@ -7076,7 +7076,12 @@ async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
|
||||
|
||||
let format = editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.perform_format(project.clone(), FormatTrigger::Manual, cx)
|
||||
editor.perform_format(
|
||||
project.clone(),
|
||||
FormatTrigger::Manual,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
fake_server
|
||||
@@ -7112,7 +7117,7 @@ async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
|
||||
});
|
||||
let format = editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.perform_format(project, FormatTrigger::Manual, cx)
|
||||
editor.perform_format(project, FormatTrigger::Manual, FormatTarget::Buffer, cx)
|
||||
})
|
||||
.unwrap();
|
||||
cx.executor().advance_clock(super::FORMAT_TIMEOUT);
|
||||
@@ -7996,7 +8001,7 @@ async fn test_completion(cx: &mut gpui::TestAppContext) {
|
||||
.unwrap()
|
||||
});
|
||||
cx.assert_editor_state(indoc! {"
|
||||
one.ˇ
|
||||
one.second_completionˇ
|
||||
two
|
||||
three
|
||||
"});
|
||||
@@ -8029,9 +8034,9 @@ async fn test_completion(cx: &mut gpui::TestAppContext) {
|
||||
cx.assert_editor_state(indoc! {"
|
||||
one.second_completionˇ
|
||||
two
|
||||
thoverlapping additional editree
|
||||
|
||||
additional edit"});
|
||||
three
|
||||
additional edit
|
||||
"});
|
||||
|
||||
cx.set_state(indoc! {"
|
||||
one.second_completion
|
||||
@@ -8091,8 +8096,8 @@ async fn test_completion(cx: &mut gpui::TestAppContext) {
|
||||
});
|
||||
cx.assert_editor_state(indoc! {"
|
||||
one.second_completion
|
||||
two siˇ
|
||||
three siˇ
|
||||
two sixth_completionˇ
|
||||
three sixth_completionˇ
|
||||
additional edit
|
||||
"});
|
||||
|
||||
@@ -8133,11 +8138,9 @@ async fn test_completion(cx: &mut gpui::TestAppContext) {
|
||||
.confirm_completion(&ConfirmCompletion::default(), cx)
|
||||
.unwrap()
|
||||
});
|
||||
cx.assert_editor_state("editor.cloˇ");
|
||||
cx.assert_editor_state("editor.closeˇ");
|
||||
handle_resolve_completion_request(&mut cx, None).await;
|
||||
apply_additional_edits.await.unwrap();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
editor.closeˇ"});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -10142,7 +10145,7 @@ async fn test_completions_with_additional_edits(cx: &mut gpui::TestAppContext) {
|
||||
.confirm_completion(&ConfirmCompletion::default(), cx)
|
||||
.unwrap()
|
||||
});
|
||||
cx.assert_editor_state(indoc! {"fn main() { let a = 2.ˇ; }"});
|
||||
cx.assert_editor_state(indoc! {"fn main() { let a = 2.Some(2)ˇ; }"});
|
||||
|
||||
cx.handle_request::<lsp::request::ResolveCompletionItem, _, _>(move |_, _, _| {
|
||||
let task_completion_item = completion_item.clone();
|
||||
@@ -10311,7 +10314,12 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
|
||||
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.perform_format(project.clone(), FormatTrigger::Manual, cx)
|
||||
editor.perform_format(
|
||||
project.clone(),
|
||||
FormatTrigger::Manual,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await;
|
||||
@@ -10325,7 +10333,12 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::Auto)
|
||||
});
|
||||
let format = editor.update(cx, |editor, cx| {
|
||||
editor.perform_format(project.clone(), FormatTrigger::Manual, cx)
|
||||
editor.perform_format(
|
||||
project.clone(),
|
||||
FormatTrigger::Manual,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
format.await.unwrap();
|
||||
assert_eq!(
|
||||
|
||||
@@ -64,7 +64,7 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
use sum_tree::Bias;
|
||||
use theme::{ActiveTheme, PlayerColor};
|
||||
use theme::{ActiveTheme, Appearance, PlayerColor};
|
||||
use ui::prelude::*;
|
||||
use ui::{h_flex, ButtonLike, ButtonStyle, ContextMenu, Tooltip};
|
||||
use util::RangeExt;
|
||||
@@ -376,6 +376,13 @@ impl EditorElement {
|
||||
cx.propagate();
|
||||
}
|
||||
});
|
||||
register_action(view, cx, |editor, action, cx| {
|
||||
if let Some(task) = editor.format_selections(action, cx) {
|
||||
task.detach_and_log_err(cx);
|
||||
} else {
|
||||
cx.propagate();
|
||||
}
|
||||
});
|
||||
register_action(view, cx, Editor::restart_language_server);
|
||||
register_action(view, cx, Editor::cancel_language_server_work);
|
||||
register_action(view, cx, Editor::show_character_palette);
|
||||
@@ -437,7 +444,8 @@ impl EditorElement {
|
||||
register_action(view, cx, Editor::revert_file);
|
||||
register_action(view, cx, Editor::revert_selected_hunks);
|
||||
register_action(view, cx, Editor::apply_selected_diff_hunks);
|
||||
register_action(view, cx, Editor::open_active_item_in_terminal)
|
||||
register_action(view, cx, Editor::open_active_item_in_terminal);
|
||||
register_action(view, cx, Editor::reload_file)
|
||||
}
|
||||
|
||||
fn register_key_listeners(&self, cx: &mut WindowContext, layout: &EditorLayout) {
|
||||
@@ -1015,8 +1023,20 @@ impl EditorElement {
|
||||
block_width = em_width;
|
||||
}
|
||||
let block_text = if let CursorShape::Block = selection.cursor_shape {
|
||||
snapshot.display_chars_at(cursor_position).next().and_then(
|
||||
|(character, _)| {
|
||||
snapshot
|
||||
.display_chars_at(cursor_position)
|
||||
.next()
|
||||
.or_else(|| {
|
||||
if cursor_column == 0 {
|
||||
snapshot
|
||||
.placeholder_text()
|
||||
.and_then(|s| s.chars().next())
|
||||
.map(|c| (c, cursor_position))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.and_then(|(character, _)| {
|
||||
let text = if character == '\n' {
|
||||
SharedString::from(" ")
|
||||
} else {
|
||||
@@ -1031,6 +1051,22 @@ impl EditorElement {
|
||||
})
|
||||
.unwrap_or(self.style.text.font());
|
||||
|
||||
// Invert the text color for the block cursor. Ensure that the text
|
||||
// color is opaque enough to be visible against the background color.
|
||||
//
|
||||
// 0.75 is an arbitrary threshold to determine if the background color is
|
||||
// opaque enough to use as a text color.
|
||||
//
|
||||
// TODO: In the future we should ensure themes have a `text_inverse` color.
|
||||
let color = if cx.theme().colors().editor_background.a < 0.75 {
|
||||
match cx.theme().appearance {
|
||||
Appearance::Dark => Hsla::black(),
|
||||
Appearance::Light => Hsla::white(),
|
||||
}
|
||||
} else {
|
||||
cx.theme().colors().editor_background
|
||||
};
|
||||
|
||||
cx.text_system()
|
||||
.shape_line(
|
||||
text,
|
||||
@@ -1038,15 +1074,14 @@ impl EditorElement {
|
||||
&[TextRun {
|
||||
len,
|
||||
font,
|
||||
color: self.style.background,
|
||||
color,
|
||||
background_color: None,
|
||||
strikethrough: None,
|
||||
underline: None,
|
||||
}],
|
||||
)
|
||||
.log_err()
|
||||
},
|
||||
)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -6060,7 +6095,7 @@ impl CursorLayout {
|
||||
origin: self.origin + origin,
|
||||
size: size(self.block_width, self.line_height),
|
||||
},
|
||||
CursorShape::Underscore => Bounds {
|
||||
CursorShape::Underline => Bounds {
|
||||
origin: self.origin
|
||||
+ origin
|
||||
+ gpui::Point::new(Pixels::ZERO, self.line_height - px(2.0)),
|
||||
|
||||
@@ -403,7 +403,10 @@ impl GitBlame {
|
||||
if this.user_triggered {
|
||||
log::error!("failed to get git blame data: {error:?}");
|
||||
let notification = format!("{:#}", error).trim().to_string();
|
||||
cx.emit(project::Event::Notification(notification));
|
||||
cx.emit(project::Event::Toast {
|
||||
notification_id: "git-blame".into(),
|
||||
message: notification,
|
||||
});
|
||||
} else {
|
||||
// If we weren't triggered by a user, we just log errors in the background, instead of sending
|
||||
// notifications.
|
||||
@@ -619,9 +622,11 @@ mod tests {
|
||||
let event = project.next_event(cx).await;
|
||||
assert_eq!(
|
||||
event,
|
||||
project::Event::Notification(
|
||||
"Failed to blame \"file.txt\": failed to get blame for \"file.txt\"".to_string()
|
||||
)
|
||||
project::Event::Toast {
|
||||
notification_id: "git-blame".into(),
|
||||
message: "Failed to blame \"file.txt\": failed to get blame for \"file.txt\""
|
||||
.to_string()
|
||||
}
|
||||
);
|
||||
|
||||
blame.update(cx, |blame, cx| {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use crate::{
|
||||
hover_popover::{self, InlayHover},
|
||||
scroll::ScrollAmount,
|
||||
Anchor, Editor, EditorSnapshot, FindAllReferences, GoToDefinition, GoToTypeDefinition, InlayId,
|
||||
Navigated, PointForPosition, SelectPhase,
|
||||
Anchor, Editor, EditorSnapshot, FindAllReferences, GoToDefinition, GoToTypeDefinition,
|
||||
GotoDefinitionKind, InlayId, Navigated, PointForPosition, SelectPhase,
|
||||
};
|
||||
use gpui::{px, AppContext, AsyncWindowContext, Model, Modifiers, Task, ViewContext};
|
||||
use language::{Bias, ToOffset};
|
||||
@@ -14,12 +14,12 @@ use project::{
|
||||
};
|
||||
use std::ops::Range;
|
||||
use theme::ActiveTheme as _;
|
||||
use util::{maybe, ResultExt, TryFutureExt};
|
||||
use util::{maybe, ResultExt, TryFutureExt as _};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HoveredLinkState {
|
||||
pub last_trigger_point: TriggerPoint,
|
||||
pub preferred_kind: LinkDefinitionKind,
|
||||
pub preferred_kind: GotoDefinitionKind,
|
||||
pub symbol_range: Option<RangeInEditor>,
|
||||
pub links: Vec<HoverLink>,
|
||||
pub task: Option<Task<Option<()>>>,
|
||||
@@ -428,12 +428,6 @@ pub fn update_inlay_link_and_hover_points(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum LinkDefinitionKind {
|
||||
Symbol,
|
||||
Type,
|
||||
}
|
||||
|
||||
pub fn show_link_definition(
|
||||
shift_held: bool,
|
||||
editor: &mut Editor,
|
||||
@@ -442,8 +436,8 @@ pub fn show_link_definition(
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
let preferred_kind = match trigger_point {
|
||||
TriggerPoint::Text(_) if !shift_held => LinkDefinitionKind::Symbol,
|
||||
_ => LinkDefinitionKind::Type,
|
||||
TriggerPoint::Text(_) if !shift_held => GotoDefinitionKind::Symbol,
|
||||
_ => GotoDefinitionKind::Type,
|
||||
};
|
||||
|
||||
let (mut hovered_link_state, is_cached) =
|
||||
@@ -505,6 +499,7 @@ pub fn show_link_definition(
|
||||
editor.hide_hovered_link(cx)
|
||||
}
|
||||
let project = editor.project.clone();
|
||||
let provider = editor.semantics_provider.clone();
|
||||
|
||||
let snapshot = snapshot.buffer_snapshot.clone();
|
||||
hovered_link_state.task = Some(cx.spawn(|this, mut cx| {
|
||||
@@ -522,54 +517,40 @@ pub fn show_link_definition(
|
||||
(range, vec![HoverLink::Url(url)])
|
||||
})
|
||||
.ok()
|
||||
} else if let Some(project) = project {
|
||||
if let Some((filename_range, filename)) =
|
||||
find_file(&buffer, project.clone(), buffer_position, &mut cx).await
|
||||
{
|
||||
let range = maybe!({
|
||||
let start =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, filename_range.start)?;
|
||||
let end =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, filename_range.end)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
});
|
||||
} else if let Some((filename_range, filename)) =
|
||||
find_file(&buffer, project.clone(), buffer_position, &mut cx).await
|
||||
{
|
||||
let range = maybe!({
|
||||
let start =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, filename_range.start)?;
|
||||
let end = snapshot.anchor_in_excerpt(excerpt_id, filename_range.end)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
});
|
||||
|
||||
Some((range, vec![HoverLink::File(filename)]))
|
||||
Some((range, vec![HoverLink::File(filename)]))
|
||||
} else if let Some(provider) = provider {
|
||||
let task = cx.update(|cx| {
|
||||
provider.definitions(&buffer, buffer_position, preferred_kind, cx)
|
||||
})?;
|
||||
if let Some(task) = task {
|
||||
task.await.ok().map(|definition_result| {
|
||||
(
|
||||
definition_result.iter().find_map(|link| {
|
||||
link.origin.as_ref().and_then(|origin| {
|
||||
let start = snapshot.anchor_in_excerpt(
|
||||
excerpt_id,
|
||||
origin.range.start,
|
||||
)?;
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, origin.range.end)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
})
|
||||
}),
|
||||
definition_result.into_iter().map(HoverLink::Text).collect(),
|
||||
)
|
||||
})
|
||||
} else {
|
||||
// query the LSP for definition info
|
||||
project
|
||||
.update(&mut cx, |project, cx| match preferred_kind {
|
||||
LinkDefinitionKind::Symbol => {
|
||||
project.definition(&buffer, buffer_position, cx)
|
||||
}
|
||||
|
||||
LinkDefinitionKind::Type => {
|
||||
project.type_definition(&buffer, buffer_position, cx)
|
||||
}
|
||||
})?
|
||||
.await
|
||||
.ok()
|
||||
.map(|definition_result| {
|
||||
(
|
||||
definition_result.iter().find_map(|link| {
|
||||
link.origin.as_ref().and_then(|origin| {
|
||||
let start = snapshot.anchor_in_excerpt(
|
||||
excerpt_id,
|
||||
origin.range.start,
|
||||
)?;
|
||||
let end = snapshot.anchor_in_excerpt(
|
||||
excerpt_id,
|
||||
origin.range.end,
|
||||
)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
})
|
||||
}),
|
||||
definition_result
|
||||
.into_iter()
|
||||
.map(HoverLink::Text)
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
@@ -708,10 +689,11 @@ pub(crate) fn find_url(
|
||||
|
||||
pub(crate) async fn find_file(
|
||||
buffer: &Model<language::Buffer>,
|
||||
project: Model<Project>,
|
||||
project: Option<Model<Project>>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Option<(Range<text::Anchor>, ResolvedPath)> {
|
||||
let project = project?;
|
||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()).ok()?;
|
||||
let scope = snapshot.language_scope_at(position);
|
||||
let (range, candidate_file_path) = surrounding_filename(snapshot, position)?;
|
||||
|
||||
@@ -195,32 +195,22 @@ fn show_hover(
|
||||
anchor: Anchor,
|
||||
ignore_timeout: bool,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
) -> Option<()> {
|
||||
if editor.pending_rename.is_some() {
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
|
||||
let snapshot = editor.snapshot(cx);
|
||||
|
||||
let (buffer, buffer_position) =
|
||||
if let Some(output) = editor.buffer.read(cx).text_anchor_for_position(anchor, cx) {
|
||||
output
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
let (buffer, buffer_position) = editor
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(anchor, cx)?;
|
||||
|
||||
let excerpt_id =
|
||||
if let Some((excerpt_id, _, _)) = editor.buffer().read(cx).excerpt_containing(anchor, cx) {
|
||||
excerpt_id
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?;
|
||||
|
||||
let project = if let Some(project) = editor.project.clone() {
|
||||
project
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
let language_registry = editor.project.as_ref()?.read(cx).languages().clone();
|
||||
let provider = editor.semantics_provider.clone()?;
|
||||
|
||||
if !ignore_timeout {
|
||||
if same_info_hover(editor, &snapshot, anchor)
|
||||
@@ -228,7 +218,7 @@ fn show_hover(
|
||||
|| editor.hover_state.diagnostic_popover.is_some()
|
||||
{
|
||||
// Hover triggered from same location as last time. Don't show again.
|
||||
return;
|
||||
return None;
|
||||
} else {
|
||||
hide_hover(editor, cx);
|
||||
}
|
||||
@@ -240,7 +230,7 @@ fn show_hover(
|
||||
.cmp(&anchor, &snapshot.buffer_snapshot)
|
||||
.is_eq()
|
||||
{
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,12 +252,7 @@ fn show_hover(
|
||||
total_delay
|
||||
};
|
||||
|
||||
// query the LSP for hover info
|
||||
let hover_request = cx.update(|cx| {
|
||||
project.update(cx, |project, cx| {
|
||||
project.hover(&buffer, buffer_position, cx)
|
||||
})
|
||||
})?;
|
||||
let hover_request = cx.update(|cx| provider.hover(&buffer, buffer_position, cx))?;
|
||||
|
||||
if let Some(delay) = delay {
|
||||
delay.await;
|
||||
@@ -377,8 +362,11 @@ fn show_hover(
|
||||
this.hover_state.diagnostic_popover = diagnostic_popover;
|
||||
})?;
|
||||
|
||||
let hovers_response = hover_request.await;
|
||||
let language_registry = project.update(&mut cx, |p, _| p.languages().clone())?;
|
||||
let hovers_response = if let Some(hover_request) = hover_request {
|
||||
hover_request.await
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let snapshot = this.update(&mut cx, |this, cx| this.snapshot(cx))?;
|
||||
let mut hover_highlights = Vec::with_capacity(hovers_response.len());
|
||||
let mut info_popovers = Vec::with_capacity(hovers_response.len());
|
||||
@@ -451,6 +439,7 @@ fn show_hover(
|
||||
});
|
||||
|
||||
editor.hover_state.info_task = Some(task);
|
||||
None
|
||||
}
|
||||
|
||||
fn same_info_hover(editor: &Editor, snapshot: &EditorSnapshot, anchor: Anchor) -> bool {
|
||||
@@ -536,7 +525,7 @@ async fn parse_blocks(
|
||||
font_family: Some(buffer_font_family),
|
||||
..Default::default()
|
||||
},
|
||||
rule_color: Color::Muted.color(cx),
|
||||
rule_color: cx.theme().colors().border,
|
||||
block_quote_border_color: Color::Muted.color(cx),
|
||||
block_quote: TextStyleRefinement {
|
||||
color: Some(Color::Muted.color(cx)),
|
||||
@@ -821,7 +810,7 @@ mod tests {
|
||||
hover_provider: Some(lsp::HoverProviderCapability::Simple(true)),
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
trigger_characters: Some(vec![".".to_string(), ":".to_string()]),
|
||||
resolve_provider: Some(false),
|
||||
resolve_provider: Some(true),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
@@ -913,15 +902,12 @@ mod tests {
|
||||
assert_eq!(counter.load(atomic::Ordering::Acquire), 1);
|
||||
|
||||
//apply a completion and check it was successfully applied
|
||||
let () = cx
|
||||
.update_editor(|editor, cx| {
|
||||
editor.context_menu_next(&Default::default(), cx);
|
||||
editor
|
||||
.confirm_completion(&ConfirmCompletion::default(), cx)
|
||||
.unwrap()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let _apply_additional_edits = cx.update_editor(|editor, cx| {
|
||||
editor.context_menu_next(&Default::default(), cx);
|
||||
editor
|
||||
.confirm_completion(&ConfirmCompletion::default(), cx)
|
||||
.unwrap()
|
||||
});
|
||||
cx.assert_editor_state(indoc! {"
|
||||
one.second_completionˇ
|
||||
two
|
||||
|
||||
@@ -591,21 +591,13 @@ impl InlayHintCache {
|
||||
drop(guard);
|
||||
cx.spawn(|editor, mut cx| async move {
|
||||
let resolved_hint_task = editor.update(&mut cx, |editor, cx| {
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.buffer(buffer_id)
|
||||
.and_then(|buffer| {
|
||||
let project = editor.project.as_ref()?;
|
||||
Some(project.update(cx, |project, cx| {
|
||||
project.resolve_inlay_hint(
|
||||
hint_to_resolve,
|
||||
buffer,
|
||||
server_id,
|
||||
cx,
|
||||
)
|
||||
}))
|
||||
})
|
||||
let buffer = editor.buffer().read(cx).buffer(buffer_id)?;
|
||||
editor.semantics_provider.as_ref()?.resolve_inlay_hint(
|
||||
hint_to_resolve,
|
||||
buffer,
|
||||
server_id,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
if let Some(resolved_hint_task) = resolved_hint_task {
|
||||
let mut resolved_hint =
|
||||
@@ -895,11 +887,13 @@ fn fetch_and_update_hints(
|
||||
) -> Task<anyhow::Result<()>> {
|
||||
cx.spawn(|editor, mut cx| async move {
|
||||
let buffer_snapshot = excerpt_buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
|
||||
let (lsp_request_limiter, multi_buffer_snapshot) = editor.update(&mut cx, |editor, cx| {
|
||||
let multi_buffer_snapshot = editor.buffer().update(cx, |buffer, cx| buffer.snapshot(cx));
|
||||
let lsp_request_limiter = Arc::clone(&editor.inlay_hint_cache.lsp_request_limiter);
|
||||
(lsp_request_limiter, multi_buffer_snapshot)
|
||||
})?;
|
||||
let (lsp_request_limiter, multi_buffer_snapshot) =
|
||||
editor.update(&mut cx, |editor, cx| {
|
||||
let multi_buffer_snapshot =
|
||||
editor.buffer().update(cx, |buffer, cx| buffer.snapshot(cx));
|
||||
let lsp_request_limiter = Arc::clone(&editor.inlay_hint_cache.lsp_request_limiter);
|
||||
(lsp_request_limiter, multi_buffer_snapshot)
|
||||
})?;
|
||||
|
||||
let (lsp_request_guard, got_throttled) = if query.invalidate.should_invalidate() {
|
||||
(None, false)
|
||||
@@ -909,12 +903,15 @@ fn fetch_and_update_hints(
|
||||
None => (Some(lsp_request_limiter.acquire().await), true),
|
||||
}
|
||||
};
|
||||
let fetch_range_to_log =
|
||||
fetch_range.start.to_point(&buffer_snapshot)..fetch_range.end.to_point(&buffer_snapshot);
|
||||
let fetch_range_to_log = fetch_range.start.to_point(&buffer_snapshot)
|
||||
..fetch_range.end.to_point(&buffer_snapshot);
|
||||
let inlay_hints_fetch_task = editor
|
||||
.update(&mut cx, |editor, cx| {
|
||||
if got_throttled {
|
||||
let query_not_around_visible_range = match editor.excerpts_for_inlay_hints_query(None, cx).remove(&query.excerpt_id) {
|
||||
let query_not_around_visible_range = match editor
|
||||
.excerpts_for_inlay_hints_query(None, cx)
|
||||
.remove(&query.excerpt_id)
|
||||
{
|
||||
Some((_, _, current_visible_range)) => {
|
||||
let visible_offset_length = current_visible_range.len();
|
||||
let double_visible_range = current_visible_range
|
||||
@@ -928,11 +925,11 @@ fn fetch_and_update_hints(
|
||||
.contains(&fetch_range.start.to_offset(&buffer_snapshot))
|
||||
&& !double_visible_range
|
||||
.contains(&fetch_range.end.to_offset(&buffer_snapshot))
|
||||
},
|
||||
}
|
||||
None => true,
|
||||
};
|
||||
if query_not_around_visible_range {
|
||||
log::trace!("Fetching inlay hints for range {fetch_range_to_log:?} got throttled and fell off the current visible range, skipping.");
|
||||
// log::trace!("Fetching inlay hints for range {fetch_range_to_log:?} got throttled and fell off the current visible range, skipping.");
|
||||
if let Some(task_ranges) = editor
|
||||
.inlay_hint_cache
|
||||
.update_tasks
|
||||
@@ -943,16 +940,12 @@ fn fetch_and_update_hints(
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = editor.buffer().read(cx).buffer(query.buffer_id)?;
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.buffer(query.buffer_id)
|
||||
.and_then(|buffer| {
|
||||
let project = editor.project.as_ref()?;
|
||||
Some(project.update(cx, |project, cx| {
|
||||
project.inlay_hints(buffer, fetch_range.clone(), cx)
|
||||
}))
|
||||
})
|
||||
.semantics_provider
|
||||
.as_ref()?
|
||||
.inlay_hints(buffer, fetch_range.clone(), cx)
|
||||
})
|
||||
.ok()
|
||||
.flatten();
|
||||
@@ -1004,12 +997,12 @@ fn fetch_and_update_hints(
|
||||
})
|
||||
.await;
|
||||
if let Some(new_update) = new_update {
|
||||
log::debug!(
|
||||
"Applying update for range {fetch_range_to_log:?}: remove from editor: {}, remove from cache: {}, add to cache: {}",
|
||||
new_update.remove_from_visible.len(),
|
||||
new_update.remove_from_cache.len(),
|
||||
new_update.add_to_cache.len()
|
||||
);
|
||||
// log::debug!(
|
||||
// "Applying update for range {fetch_range_to_log:?}: remove from editor: {}, remove from cache: {}, add to cache: {}",
|
||||
// new_update.remove_from_visible.len(),
|
||||
// new_update.remove_from_cache.len(),
|
||||
// new_update.add_to_cache.len()
|
||||
// );
|
||||
log::trace!("New update: {new_update:?}");
|
||||
editor
|
||||
.update(&mut cx, |editor, cx| {
|
||||
|
||||
@@ -27,6 +27,7 @@ use rpc::proto::{self, update_view, PeerId};
|
||||
use settings::Settings;
|
||||
use workspace::item::{Dedup, ItemSettings, SerializableItem, TabContentParams};
|
||||
|
||||
use project::lsp_store::FormatTarget;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
borrow::Cow,
|
||||
@@ -722,7 +723,12 @@ impl Item for Editor {
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
if format {
|
||||
this.update(&mut cx, |editor, cx| {
|
||||
editor.perform_format(project.clone(), FormatTrigger::Save, cx)
|
||||
editor.perform_format(
|
||||
project.clone(),
|
||||
FormatTrigger::Save,
|
||||
FormatTarget::Buffer,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::actions::FormatSelections;
|
||||
use crate::{
|
||||
actions::Format, selections_collection::SelectionsCollection, Copy, CopyPermalinkToLine, Cut,
|
||||
DisplayPoint, DisplaySnapshot, Editor, EditorMode, FindAllReferences, GoToDeclaration,
|
||||
@@ -8,6 +7,8 @@ use crate::{
|
||||
};
|
||||
use gpui::prelude::FluentBuilder;
|
||||
use gpui::{DismissEvent, Pixels, Point, Subscription, View, ViewContext};
|
||||
use std::ops::Range;
|
||||
use text::PointUtf16;
|
||||
use workspace::OpenInTerminal;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -164,6 +165,12 @@ pub fn deploy_context_menu(
|
||||
} else {
|
||||
"Reveal in File Manager"
|
||||
};
|
||||
let has_selections = editor
|
||||
.selections
|
||||
.all::<PointUtf16>(cx)
|
||||
.into_iter()
|
||||
.any(|s| !s.is_empty());
|
||||
|
||||
ui::ContextMenu::build(cx, |menu, _cx| {
|
||||
let builder = menu
|
||||
.on_blur_subscription(Subscription::new(|| {}))
|
||||
@@ -175,6 +182,9 @@ pub fn deploy_context_menu(
|
||||
.separator()
|
||||
.action("Rename Symbol", Box::new(Rename))
|
||||
.action("Format Buffer", Box::new(Format))
|
||||
.when(has_selections, |cx| {
|
||||
cx.action("Format Selections", Box::new(FormatSelections))
|
||||
})
|
||||
.action(
|
||||
"Code Actions",
|
||||
Box::new(ToggleCodeActions {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Editor, EditorEvent};
|
||||
use crate::{Editor, EditorEvent, SemanticsProvider};
|
||||
use collections::HashSet;
|
||||
use futures::{channel::mpsc, future::join_all};
|
||||
use gpui::{AppContext, EventEmitter, FocusableView, Model, Render, Subscription, Task, View};
|
||||
@@ -6,7 +6,7 @@ use language::{Buffer, BufferEvent, Capability};
|
||||
use multi_buffer::{ExcerptRange, MultiBuffer};
|
||||
use project::Project;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{any::TypeId, ops::Range, time::Duration};
|
||||
use std::{any::TypeId, ops::Range, rc::Rc, time::Duration};
|
||||
use text::ToOffset;
|
||||
use ui::prelude::*;
|
||||
use workspace::{
|
||||
@@ -35,6 +35,12 @@ struct RecalculateDiff {
|
||||
debounce: bool,
|
||||
}
|
||||
|
||||
/// A provider of code semantics for branch buffers.
|
||||
///
|
||||
/// Requests in edited regions will return nothing, but requests in unchanged
|
||||
/// regions will be translated into the base buffer's coordinates.
|
||||
struct BranchBufferSemanticsProvider(Rc<dyn SemanticsProvider>);
|
||||
|
||||
impl ProposedChangesEditor {
|
||||
pub fn new<T: ToOffset>(
|
||||
buffers: Vec<ProposedChangesBuffer<T>>,
|
||||
@@ -66,6 +72,13 @@ impl ProposedChangesEditor {
|
||||
editor: cx.new_view(|cx| {
|
||||
let mut editor = Editor::for_multibuffer(multibuffer.clone(), project, true, cx);
|
||||
editor.set_expand_all_diff_hunks();
|
||||
editor.set_completion_provider(None);
|
||||
editor.clear_code_action_providers();
|
||||
editor.set_semantics_provider(
|
||||
editor
|
||||
.semantics_provider()
|
||||
.map(|provider| Rc::new(BranchBufferSemanticsProvider(provider)) as _),
|
||||
);
|
||||
editor
|
||||
}),
|
||||
recalculate_diffs_tx,
|
||||
@@ -76,7 +89,7 @@ impl ProposedChangesEditor {
|
||||
|
||||
while recalculate_diff.debounce {
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(250))
|
||||
.timer(Duration::from_millis(50))
|
||||
.await;
|
||||
let mut had_further_changes = false;
|
||||
while let Ok(next_recalculate_diff) = recalculate_diffs_rx.try_next() {
|
||||
@@ -245,3 +258,103 @@ impl ToolbarItemView for ProposedChangesEditorToolbar {
|
||||
self.get_toolbar_item_location()
|
||||
}
|
||||
}
|
||||
|
||||
impl BranchBufferSemanticsProvider {
|
||||
fn to_base(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
positions: &[text::Anchor],
|
||||
cx: &AppContext,
|
||||
) -> Option<Model<Buffer>> {
|
||||
let base_buffer = buffer.read(cx).diff_base_buffer()?;
|
||||
let version = base_buffer.read(cx).version();
|
||||
if positions
|
||||
.iter()
|
||||
.any(|position| !version.observed(position.timestamp))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
Some(base_buffer)
|
||||
}
|
||||
}
|
||||
|
||||
impl SemanticsProvider for BranchBufferSemanticsProvider {
|
||||
fn hover(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Vec<project::Hover>>> {
|
||||
let buffer = self.to_base(buffer, &[position], cx)?;
|
||||
self.0.hover(&buffer, position, cx)
|
||||
}
|
||||
|
||||
fn inlay_hints(
|
||||
&self,
|
||||
buffer: Model<Buffer>,
|
||||
range: Range<text::Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<Vec<project::InlayHint>>>> {
|
||||
let buffer = self.to_base(&buffer, &[range.start, range.end], cx)?;
|
||||
self.0.inlay_hints(buffer, range, cx)
|
||||
}
|
||||
|
||||
fn resolve_inlay_hint(
|
||||
&self,
|
||||
hint: project::InlayHint,
|
||||
buffer: Model<Buffer>,
|
||||
server_id: lsp::LanguageServerId,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<project::InlayHint>>> {
|
||||
let buffer = self.to_base(&buffer, &[], cx)?;
|
||||
self.0.resolve_inlay_hint(hint, buffer, server_id, cx)
|
||||
}
|
||||
|
||||
fn supports_inlay_hints(&self, buffer: &Model<Buffer>, cx: &AppContext) -> bool {
|
||||
if let Some(buffer) = self.to_base(&buffer, &[], cx) {
|
||||
self.0.supports_inlay_hints(&buffer, cx)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn document_highlights(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<gpui::Result<Vec<project::DocumentHighlight>>>> {
|
||||
let buffer = self.to_base(&buffer, &[position], cx)?;
|
||||
self.0.document_highlights(&buffer, position, cx)
|
||||
}
|
||||
|
||||
fn definitions(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
kind: crate::GotoDefinitionKind,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<gpui::Result<Vec<project::LocationLink>>>> {
|
||||
let buffer = self.to_base(&buffer, &[position], cx)?;
|
||||
self.0.definitions(&buffer, position, kind, cx)
|
||||
}
|
||||
|
||||
fn range_for_rename(
|
||||
&self,
|
||||
_: &Model<Buffer>,
|
||||
_: text::Anchor,
|
||||
_: &mut AppContext,
|
||||
) -> Option<Task<gpui::Result<Option<Range<text::Anchor>>>>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn perform_rename(
|
||||
&self,
|
||||
_: &Model<Buffer>,
|
||||
_: text::Anchor,
|
||||
_: String,
|
||||
_: &mut AppContext,
|
||||
) -> Option<Task<gpui::Result<project::ProjectTransaction>>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,10 +67,11 @@ fn task_context_with_editor(
|
||||
variables
|
||||
};
|
||||
|
||||
let context_task = project.update(cx, |project, cx| {
|
||||
project.task_context_for_location(captured_variables, location.clone(), cx)
|
||||
});
|
||||
cx.spawn(|_| context_task)
|
||||
project.update(cx, |project, cx| {
|
||||
project.task_store().update(cx, |task_store, cx| {
|
||||
task_store.task_context_for_location(captured_variables, location, cx)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn task_context(workspace: &Workspace, cx: &mut WindowContext<'_>) -> AsyncTask<TaskContext> {
|
||||
|
||||
@@ -13,7 +13,6 @@ use itertools::Itertools;
|
||||
use language::{Buffer, BufferSnapshot, LanguageRegistry};
|
||||
use multi_buffer::{ExcerptRange, ToPoint};
|
||||
use parking_lot::RwLock;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{FakeFs, Project};
|
||||
use std::{
|
||||
any::TypeId,
|
||||
|
||||
@@ -25,7 +25,6 @@ fs.workspace = true
|
||||
git.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
isahc_http_client.workspace = true
|
||||
language.workspace = true
|
||||
languages.workspace = true
|
||||
node_runtime.workspace = true
|
||||
@@ -36,3 +35,4 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
|
||||
@@ -12,6 +12,7 @@ use language::LanguageRegistry;
|
||||
use node_runtime::NodeRuntime;
|
||||
use open_ai::OpenAiEmbeddingModel;
|
||||
use project::Project;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use semantic_index::{
|
||||
EmbeddingProvider, OpenAiEmbeddingProvider, ProjectIndex, SemanticDb, Status,
|
||||
};
|
||||
@@ -100,7 +101,7 @@ fn main() -> Result<()> {
|
||||
|
||||
gpui::App::headless().run(move |cx| {
|
||||
let executor = cx.background_executor().clone();
|
||||
let client = isahc_http_client::IsahcHttpClient::new(None, None);
|
||||
let client = Arc::new(ReqwestClient::user_agent("Zed LLM evals").unwrap());
|
||||
cx.set_http_client(client.clone());
|
||||
match cli.command {
|
||||
Commands::Fetch {} => {
|
||||
|
||||
@@ -56,7 +56,6 @@ wit-component.workspace = true
|
||||
workspace.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
isahc_http_client.workspace = true
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, features = ["test-support"] }
|
||||
@@ -64,5 +63,5 @@ gpui = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
parking_lot.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
tokio.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user