Compare commits
2 Commits
nightly-tr
...
context-me
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3772a467e0 | ||
|
|
1be27e7270 |
2
.github/workflows/bump_collab_staging.yml
vendored
2
.github/workflows/bump_collab_staging.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
2
.github/workflows/bump_patch_version.yml
vendored
2
.github/workflows/bump_patch_version.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.branch }}
|
||||
ssh-key: ${{ secrets.ZED_BOT_DEPLOY_KEY }}
|
||||
|
||||
61
.github/workflows/ci.yml
vendored
61
.github/workflows/ci.yml
vendored
@@ -34,7 +34,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -96,13 +96,10 @@ jobs:
|
||||
uses: ./.github/actions/run_tests
|
||||
|
||||
- name: Build collab
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p collab
|
||||
run: cargo build -p collab
|
||||
|
||||
- name: Build other binaries and features
|
||||
run: |
|
||||
RUSTFLAGS="-D warnings" cargo build --workspace --bins --all-features
|
||||
cargo check -p gpui --features "macos-blade"
|
||||
RUSTFLAGS="-D warnings" cargo build -p remote_server
|
||||
run: cargo build --workspace --bins --all-features; cargo check -p gpui --features "macos-blade"
|
||||
|
||||
linux_tests:
|
||||
timeout-minutes: 60
|
||||
@@ -114,7 +111,7 @@ jobs:
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -134,33 +131,7 @@ jobs:
|
||||
uses: ./.github/actions/run_tests
|
||||
|
||||
- name: Build Zed
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p zed
|
||||
|
||||
build_remote_server:
|
||||
timeout-minutes: 60
|
||||
name: (Linux) Build Remote Server
|
||||
runs-on:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@23bce251a8cd2ffc3c1075eaa2367cf899916d84 # v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
cache-provider: "buildjet"
|
||||
|
||||
- name: Install Clang & Mold
|
||||
run: ./script/remote-server && ./script/install-mold 2.34.0
|
||||
|
||||
- name: Build Remote Server
|
||||
run: RUSTFLAGS="-D warnings" cargo build -p remote_server
|
||||
run: cargo build -p zed
|
||||
|
||||
# todo(windows): Actually run the tests
|
||||
windows_tests:
|
||||
@@ -169,7 +140,7 @@ jobs:
|
||||
runs-on: hosted-windows-1
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -184,7 +155,7 @@ jobs:
|
||||
run: cargo xtask clippy
|
||||
|
||||
- name: Build Zed
|
||||
run: $env:RUSTFLAGS="-D warnings"; cargo build
|
||||
run: cargo build -p zed
|
||||
|
||||
bundle-mac:
|
||||
timeout-minutes: 60
|
||||
@@ -210,7 +181,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
# We need to fetch more than one commit so that `script/draft-release-notes`
|
||||
# is able to diff between the current and previous tag.
|
||||
@@ -248,20 +219,20 @@ jobs:
|
||||
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
|
||||
|
||||
- name: Upload app bundle (universal) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}.dmg
|
||||
path: target/release/Zed.dmg
|
||||
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
|
||||
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
|
||||
|
||||
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
|
||||
@@ -295,7 +266,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -312,7 +283,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
|
||||
@@ -342,7 +313,7 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -359,7 +330,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4
|
||||
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
|
||||
|
||||
2
.github/workflows/close_stale_issues.yml
vendored
2
.github/workflows/close_stale_issues.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
We're working to clean up our issue tracker by closing older issues that might not be relevant anymore. Are you able to reproduce this issue in the latest version of Zed? If so, please let us know by commenting on this issue and we will keep it open; otherwise, we'll close it in 7 days. Feel free to open a new issue if you're seeing this message after the issue has been closed.
|
||||
|
||||
Thanks for your help!
|
||||
close-issue-message: "This issue was closed due to inactivity. If you're still experiencing this problem, feel free to ping a Zed team member to reopen this issue or open a new one."
|
||||
close-issue-message: "This issue was closed due to inactivity; feel free to open a new issue if you're still experiencing this problem!"
|
||||
# We will increase `days-before-stale` to 365 on or after Jan 24th,
|
||||
# 2024. This date marks one year since migrating issues from
|
||||
# 'community' to 'zed' repository. The migration added activity to all
|
||||
|
||||
2
.github/workflows/danger.yml
vendored
2
.github/workflows/danger.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- uses: pnpm/action-setup@fe02b34f77f8bc703788d5817da081398fad5dd2 # v4.0.0
|
||||
with:
|
||||
|
||||
10
.github/workflows/deploy_cloudflare.yml
vendored
10
.github/workflows/deploy_cloudflare.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -36,28 +36,28 @@ jobs:
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Deploy Docs
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: pages deploy target/deploy --project-name=docs
|
||||
|
||||
- name: Deploy Install
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: r2 object put -f script/install.sh zed-open-source-website-assets/install.sh
|
||||
|
||||
- name: Deploy Docs Workers
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: deploy .cloudflare/docs-proxy/src/worker.js
|
||||
|
||||
- name: Deploy Install Workers
|
||||
uses: cloudflare/wrangler-action@9681c2997648301493e78cacbfb790a9f19c833f # v3
|
||||
uses: cloudflare/wrangler-action@168bc28b7078db16f6f1ecc26477fc2248592143 # v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
|
||||
11
.github/workflows/deploy_collab.yml
vendored
11
.github/workflows/deploy_collab.yml
vendored
@@ -3,7 +3,8 @@ name: Publish Collab Server Image
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- collab-production
|
||||
# Pause production deploys while we investigate an issue.
|
||||
# - collab-production
|
||||
- collab-staging
|
||||
|
||||
env:
|
||||
@@ -17,7 +18,7 @@ jobs:
|
||||
- test
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -36,7 +37,7 @@ jobs:
|
||||
needs: style
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
@@ -71,7 +72,7 @@ jobs:
|
||||
run: doctl registry login
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
@@ -97,7 +98,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- uses: pnpm/action-setup@fe02b34f77f8bc703788d5817da081398fad5dd2 # v4.0.0
|
||||
with:
|
||||
|
||||
50
.github/workflows/mark-for-nightly.yml
vendored
50
.github/workflows/mark-for-nightly.yml
vendored
@@ -1,50 +0,0 @@
|
||||
name: Mark for Nightly
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
mark-for-nightly:
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/nightly')
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get PR details and check if open
|
||||
id: pr_details
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
github-token: ${{secrets.GITHUB_TOKEN}}
|
||||
script: |
|
||||
const pr = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number
|
||||
});
|
||||
if (pr.data.state !== 'open') {
|
||||
console.log('PR is not open. Skipping.');
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
sha: pr.data.head.sha,
|
||||
ref: pr.data.head.ref
|
||||
}
|
||||
|
||||
- name: Edit comment with SHA
|
||||
if: steps.pr_details.outputs.result != 'null'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
github-token: ${{secrets.GITHUB_TOKEN}}
|
||||
script: |
|
||||
const sha = '${{ fromJson(steps.pr_details.outputs.result).sha }}';
|
||||
const originalBody = context.payload.comment.body;
|
||||
const updatedBody = originalBody.replace('/nightly', `/nightly:${sha}`);
|
||||
const finalBody = `${updatedBody}\n\nThis PR's current HEAD (${sha}) will be included in the next nightly build. To remove it from the nightly build, comment with \`/nightly:remove\`.`;
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: context.payload.comment.id,
|
||||
body: finalBody
|
||||
});
|
||||
2
.github/workflows/publish_extension_cli.yml
vendored
2
.github/workflows/publish_extension_cli.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
- ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/randomized_tests.yml
vendored
2
.github/workflows/randomized_tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
|
||||
2
.github/workflows/release_actions.yml
vendored
2
.github/workflows/release_actions.yml
vendored
@@ -1,5 +1,3 @@
|
||||
name: Release Actions
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
76
.github/workflows/release_nightly.yml
vendored
76
.github/workflows/release_nightly.yml
vendored
@@ -14,63 +14,6 @@ env:
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
jobs:
|
||||
prepare-nightly-train:
|
||||
name: Prepare nightly-train branch
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Create or update nightly-train branch
|
||||
run: |
|
||||
git config user.name github-actions
|
||||
git config user.email github-actions@github.com
|
||||
git checkout -B nightly-train main
|
||||
|
||||
# Find all comments with /nightly:SHA
|
||||
nightly_comments=$(gh api graphql -f query='
|
||||
query($owner:String!, $repo:String!) {
|
||||
repository(owner:$owner, name:$repo) {
|
||||
pullRequests(last:128, states:OPEN) {
|
||||
nodes {
|
||||
number
|
||||
comments(last:128) {
|
||||
nodes {
|
||||
body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}' -f owner=${{ github.repository_owner }} -f repo=${{ github.event.repository.name }} | jq -r '.data.repository.pullRequests.nodes[] | select(.comments.nodes[].body | contains("/nightly:")) | {number: .number, comments: [.comments.nodes[].body | select(contains("/nightly:"))]} | @json')
|
||||
|
||||
# Process each PR
|
||||
echo "$nightly_comments" | jq -c '.' | while read -r pr; do
|
||||
pr_number=$(echo "$pr" | jq -r '.number')
|
||||
last_command=$(echo "$pr" | jq -r '.comments[-1]')
|
||||
|
||||
sha=$(echo "$last_command" | grep -oP '/nightly:\K[a-f0-9]+')
|
||||
if [ -n "$sha" ]; then
|
||||
if git cherry-pick $sha; then
|
||||
if cargo check --workspace; then
|
||||
echo "Successfully cherry-picked and checked $sha from PR #$pr_number"
|
||||
else
|
||||
git reset --hard HEAD~1
|
||||
echo "::warning::Commit $sha from PR #$pr_number failed cargo check, excluding from nightly-train"
|
||||
fi
|
||||
else
|
||||
echo "::warning::Failed to cherry-pick $sha from PR #$pr_number, skipping"
|
||||
git cherry-pick --abort
|
||||
fi
|
||||
elif [[ $last_command == *"/nightly:remove"* ]]; then
|
||||
echo "Skipping PR #$pr_number due to /nightly:remove command"
|
||||
fi
|
||||
done
|
||||
|
||||
git push origin nightly-train --force
|
||||
|
||||
style:
|
||||
timeout-minutes: 60
|
||||
name: Check formatting and Clippy lints
|
||||
@@ -78,14 +21,12 @@ jobs:
|
||||
runs-on:
|
||||
- self-hosted
|
||||
- test
|
||||
needs: prepare-nightly-train
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
fetch-depth: 0
|
||||
ref: nightly-train
|
||||
|
||||
- name: Run style checks
|
||||
uses: ./.github/actions/check_style
|
||||
@@ -103,10 +44,9 @@ jobs:
|
||||
needs: style
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
ref: nightly-train
|
||||
|
||||
- name: Run tests
|
||||
uses: ./.github/actions/run_tests
|
||||
@@ -135,10 +75,9 @@ jobs:
|
||||
node-version: "18"
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
ref: nightly-train
|
||||
|
||||
- name: Set release channel to nightly
|
||||
run: |
|
||||
@@ -170,10 +109,9 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
ref: nightly-train
|
||||
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
@@ -211,10 +149,9 @@ jobs:
|
||||
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
clean: false
|
||||
ref: nightly-train
|
||||
|
||||
- name: Install Linux dependencies
|
||||
run: ./script/linux
|
||||
@@ -245,10 +182,9 @@ jobs:
|
||||
- bundle-linux-arm
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: nightly-train
|
||||
|
||||
- name: Update nightly tag
|
||||
run: |
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
name: Update All Top Ranking Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 */12 * * *"
|
||||
@@ -10,16 +8,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
|
||||
- name: Install Python 3.13
|
||||
run: uv python install 3.13
|
||||
- name: Install dependencies
|
||||
run: uv sync --project script/update_top_ranking_issues -p 3.13
|
||||
- name: Run script
|
||||
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
|
||||
python-version: "3.11"
|
||||
architecture: "x64"
|
||||
cache: "pip"
|
||||
- run: pip install -r script/update_top_ranking_issues/requirements.txt
|
||||
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
name: Update Weekly Top Ranking Issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 15 * * *"
|
||||
@@ -10,16 +8,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
steps:
|
||||
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "script/update_top_ranking_issues/pyproject.toml"
|
||||
- name: Install Python 3.13
|
||||
run: uv python install 3.13
|
||||
- name: Install dependencies
|
||||
run: uv sync --project script/update_top_ranking_issues -p 3.13
|
||||
- name: Run script
|
||||
run: uv run --project script/update_top_ranking_issues script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
|
||||
python-version: "3.11"
|
||||
architecture: "x64"
|
||||
cache: "pip"
|
||||
- run: pip install -r script/update_top_ranking_issues/requirements.txt
|
||||
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
|
||||
|
||||
513
Cargo.lock
generated
513
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
21
Cargo.toml
21
Cargo.toml
@@ -122,7 +122,7 @@ members = [
|
||||
"crates/ui",
|
||||
"crates/ui_input",
|
||||
"crates/ui_macros",
|
||||
"crates/reqwest_client",
|
||||
"crates/ureq_client",
|
||||
"crates/util",
|
||||
"crates/vcs_menu",
|
||||
"crates/vim",
|
||||
@@ -145,6 +145,7 @@ members = [
|
||||
"extensions/elm",
|
||||
"extensions/emmet",
|
||||
"extensions/erlang",
|
||||
"extensions/gleam",
|
||||
"extensions/glsl",
|
||||
"extensions/haskell",
|
||||
"extensions/html",
|
||||
@@ -156,6 +157,7 @@ members = [
|
||||
"extensions/proto",
|
||||
"extensions/purescript",
|
||||
"extensions/ruff",
|
||||
"extensions/ruby",
|
||||
"extensions/slash-commands-example",
|
||||
"extensions/snippets",
|
||||
"extensions/svelte",
|
||||
@@ -219,7 +221,7 @@ git = { path = "crates/git" }
|
||||
git_hosting_providers = { path = "crates/git_hosting_providers" }
|
||||
go_to_line = { path = "crates/go_to_line" }
|
||||
google_ai = { path = "crates/google_ai" }
|
||||
gpui = { path = "crates/gpui", default-features = false, features = ["http_client"]}
|
||||
gpui = { path = "crates/gpui" }
|
||||
gpui_macros = { path = "crates/gpui_macros" }
|
||||
headless = { path = "crates/headless" }
|
||||
html_to_markdown = { path = "crates/html_to_markdown" }
|
||||
@@ -299,6 +301,7 @@ title_bar = { path = "crates/title_bar" }
|
||||
ui = { path = "crates/ui" }
|
||||
ui_input = { path = "crates/ui_input" }
|
||||
ui_macros = { path = "crates/ui_macros" }
|
||||
ureq_client = { path = "crates/ureq_client" }
|
||||
util = { path = "crates/util" }
|
||||
vcs_menu = { path = "crates/vcs_menu" }
|
||||
vim = { path = "crates/vim" }
|
||||
@@ -326,7 +329,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
|
||||
async-recursion = "1.0.0"
|
||||
async-tar = "0.5.0"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.24"
|
||||
async-tungstenite = "0.28"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
base64 = "0.22"
|
||||
@@ -335,7 +338,6 @@ blade-graphics = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "e142a3a5e678eb6a13e642ad8401b1f3aa38e969" }
|
||||
blake3 = "1.5.3"
|
||||
bytes = "1.0"
|
||||
cargo_metadata = "0.18"
|
||||
cargo_toml = "0.20"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
@@ -391,14 +393,14 @@ pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
rand = "0.8.5"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "fd110f6998da16bbca97b6dddda9be7827c50e29", default-features = false, features = ["charset", "http2", "macos-system-configuration", "rustls-tls-native-roots", "stream"]}
|
||||
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "fd110f6998da16bbca97b6dddda9be7827c50e29" }
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { version = "0.15", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
] }
|
||||
rustc-demangle = "0.1.23"
|
||||
rust-embed = { version = "8.4", features = ["include-exclude"] }
|
||||
rustls = "0.20.3"
|
||||
rustls = "0.21.12"
|
||||
rustls-native-certs = "0.8.0"
|
||||
schemars = { version = "0.8", features = ["impl_json_schema"] }
|
||||
semver = "1.0"
|
||||
@@ -436,7 +438,7 @@ time = { version = "0.3", features = [
|
||||
] }
|
||||
tiny_http = "0.8"
|
||||
toml = "0.8"
|
||||
tokio = { version = "1" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tower-http = "0.4.4"
|
||||
tree-sitter = { version = "0.23", features = ["wasm"] }
|
||||
tree-sitter-bash = "0.23"
|
||||
@@ -449,7 +451,6 @@ tree-sitter-go = "0.23"
|
||||
tree-sitter-go-mod = { git = "https://github.com/zed-industries/tree-sitter-go-mod", rev = "a9aea5e358cde4d0f8ff20b7bc4fa311e359c7ca", package = "tree-sitter-gomod" }
|
||||
tree-sitter-gowork = { git = "https://github.com/zed-industries/tree-sitter-go-work", rev = "acb0617bf7f4fda02c6217676cc64acb89536dc7" }
|
||||
tree-sitter-heex = { git = "https://github.com/zed-industries/tree-sitter-heex", rev = "1dd45142fbb05562e35b2040c6129c9bca346592" }
|
||||
tree-sitter-diff = "0.1.0"
|
||||
tree-sitter-html = "0.20"
|
||||
tree-sitter-jsdoc = "0.23"
|
||||
tree-sitter-json = "0.23"
|
||||
@@ -477,11 +478,9 @@ wasmtime = { version = "24", default-features = false, features = [
|
||||
wasmtime-wasi = "24"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.201"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
git = "https://github.com/zed-industries/async-stripe"
|
||||
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
version = "0.39"
|
||||
default-features = false
|
||||
features = [
|
||||
"runtime-tokio-hyper-rustls",
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
[build]
|
||||
dockerfile = "Dockerfile-cross"
|
||||
@@ -13,9 +13,28 @@ ARG GITHUB_SHA
|
||||
|
||||
ENV GITHUB_SHA=$GITHUB_SHA
|
||||
|
||||
# Also add `cmake`, since we need it to build `wasmtime`.
|
||||
# At some point in the past 3 weeks, additional dependencies on `xkbcommon` and
|
||||
# `xkbcommon-x11` were introduced into collab.
|
||||
#
|
||||
# A `git bisect` points to this commit as being the culprit: `b8e6098f60e5dabe98fe8281f993858dacc04a55`.
|
||||
#
|
||||
# Now when we try to build collab for the Docker image, it fails with the following
|
||||
# error:
|
||||
#
|
||||
# ```
|
||||
# 985.3 = note: /usr/bin/ld: cannot find -lxkbcommon: No such file or directory
|
||||
# 985.3 /usr/bin/ld: cannot find -lxkbcommon-x11: No such file or directory
|
||||
# 985.3 collect2: error: ld returned 1 exit status
|
||||
# ```
|
||||
#
|
||||
# The last successful deploys were at:
|
||||
# - Staging: `4f408ec65a3867278322a189b4eb20f1ab51f508`
|
||||
# - Production: `fc4c533d0a8c489e5636a4249d2b52a80039fbd7`
|
||||
#
|
||||
# Installing these as a temporary workaround, but I think ideally we'd want to figure
|
||||
# out what caused them to be included in the first place.
|
||||
RUN apt-get update; \
|
||||
apt-get install -y --no-install-recommends cmake
|
||||
apt-get install -y --no-install-recommends libxkbcommon-dev libxkbcommon-x11-dev
|
||||
|
||||
RUN --mount=type=cache,target=./script/node_modules \
|
||||
--mount=type=cache,target=/usr/local/cargo/registry \
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG CROSS_BASE_IMAGE
|
||||
FROM ${CROSS_BASE_IMAGE}
|
||||
WORKDIR /app
|
||||
ARG TZ=Etc/UTC \
|
||||
LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV CARGO_TERM_COLOR=always
|
||||
|
||||
COPY script/install-mold script/
|
||||
RUN ./script/install-mold "2.34.0"
|
||||
COPY script/remote-server script/
|
||||
RUN ./script/remote-server
|
||||
|
||||
COPY . .
|
||||
@@ -1,16 +0,0 @@
|
||||
.git
|
||||
.github
|
||||
**/.gitignore
|
||||
**/.gitkeep
|
||||
.gitattributes
|
||||
.mailmap
|
||||
**/target
|
||||
zed.xcworkspace
|
||||
.DS_Store
|
||||
compose.yml
|
||||
plugins/bin
|
||||
script/node_modules
|
||||
styles/node_modules
|
||||
crates/collab/static/styles.css
|
||||
vendor/bin
|
||||
assets/themes/
|
||||
@@ -20,7 +20,6 @@
|
||||
"bashrc": "terminal",
|
||||
"bmp": "image",
|
||||
"c": "c",
|
||||
"c++": "cpp",
|
||||
"cc": "cpp",
|
||||
"cjs": "javascript",
|
||||
"coffee": "coffeescript",
|
||||
@@ -28,7 +27,6 @@
|
||||
"cpp": "cpp",
|
||||
"css": "css",
|
||||
"csv": "storage",
|
||||
"cxx": "cpp",
|
||||
"cts": "typescript",
|
||||
"dart": "dart",
|
||||
"dat": "storage",
|
||||
@@ -68,13 +66,11 @@
|
||||
"heex": "elixir",
|
||||
"heic": "image",
|
||||
"heif": "image",
|
||||
"hh": "cpp",
|
||||
"hpp": "cpp",
|
||||
"hrl": "erlang",
|
||||
"hs": "haskell",
|
||||
"htm": "template",
|
||||
"html": "template",
|
||||
"hxx": "cpp",
|
||||
"ib": "storage",
|
||||
"ico": "image",
|
||||
"ini": "settings",
|
||||
|
||||
@@ -664,8 +664,7 @@
|
||||
"shift-up": "terminal::ScrollLineUp",
|
||||
"shift-down": "terminal::ScrollLineDown",
|
||||
"shift-home": "terminal::ScrollToTop",
|
||||
"shift-end": "terminal::ScrollToBottom",
|
||||
"ctrl-shift-space": "terminal::ToggleViMode"
|
||||
"shift-end": "terminal::ScrollToBottom"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -395,7 +395,6 @@
|
||||
// Change the default action on `menu::Confirm` by setting the parameter
|
||||
// "alt-cmd-o": ["projects::OpenRecent", {"create_new_window": true }],
|
||||
"alt-cmd-o": "projects::OpenRecent",
|
||||
"ctrl-cmd-o": "projects::OpenRemote",
|
||||
"alt-cmd-b": "branches::OpenRecent",
|
||||
"ctrl-~": "workspace::NewTerminal",
|
||||
"cmd-s": "workspace::Save",
|
||||
@@ -679,8 +678,7 @@
|
||||
"cmd-home": "terminal::ScrollToTop",
|
||||
"cmd-end": "terminal::ScrollToBottom",
|
||||
"shift-home": "terminal::ScrollToTop",
|
||||
"shift-end": "terminal::ScrollToBottom",
|
||||
"ctrl-shift-space": "terminal::ToggleViMode"
|
||||
"shift-end": "terminal::ScrollToBottom"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -128,10 +128,6 @@
|
||||
"shift-m": "vim::WindowMiddle",
|
||||
"shift-l": "vim::WindowBottom",
|
||||
// z commands
|
||||
"z enter": ["workspace::SendKeystrokes", "z t ^"],
|
||||
"z -": ["workspace::SendKeystrokes", "z b ^"],
|
||||
"z ^": ["workspace::SendKeystrokes", "shift-h k z b ^"],
|
||||
"z +": ["workspace::SendKeystrokes", "shift-l j z t ^"],
|
||||
"z t": "editor::ScrollCursorTop",
|
||||
"z z": "editor::ScrollCursorCenter",
|
||||
"z .": ["workspace::SendKeystrokes", "z z ^"],
|
||||
@@ -256,7 +252,6 @@
|
||||
"@": ["vim::PushOperator", "ReplayRegister"],
|
||||
"ctrl-pagedown": "pane::ActivateNextItem",
|
||||
"ctrl-pageup": "pane::ActivatePrevItem",
|
||||
"insert": "vim::InsertBefore",
|
||||
// tree-sitter related commands
|
||||
"[ x": "editor::SelectLargerSyntaxNode",
|
||||
"] x": "editor::SelectSmallerSyntaxNode",
|
||||
@@ -339,8 +334,7 @@
|
||||
"ctrl-t": "vim::Indent",
|
||||
"ctrl-d": "vim::Outdent",
|
||||
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
|
||||
"ctrl-r": ["vim::PushOperator", "Register"],
|
||||
"insert": "vim::ToggleReplace"
|
||||
"ctrl-r": ["vim::PushOperator", "Register"]
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -359,8 +353,7 @@
|
||||
"ctrl-k": ["vim::PushOperator", { "Digraph": {} }],
|
||||
"backspace": "vim::UndoReplace",
|
||||
"tab": "vim::Tab",
|
||||
"enter": "vim::Enter",
|
||||
"insert": "vim::InsertBefore"
|
||||
"enter": "vim::Enter"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -118,8 +118,8 @@
|
||||
// "bar"
|
||||
// 2. A block that surrounds the following character
|
||||
// "block"
|
||||
// 3. An underline / underscore that runs along the following character
|
||||
// "underline"
|
||||
// 3. An underline that runs along the following character
|
||||
// "underscore"
|
||||
// 4. A box drawn around the following character
|
||||
// "hollow"
|
||||
//
|
||||
@@ -494,14 +494,7 @@
|
||||
// Position of the close button on the editor tabs.
|
||||
"close_position": "right",
|
||||
// Whether to show the file icon for a tab.
|
||||
"file_icons": false,
|
||||
// What to do after closing the current tab.
|
||||
//
|
||||
// 1. Activate the tab that was open previously (default)
|
||||
// "History"
|
||||
// 2. Activate the neighbour tab (prefers the right one, if present)
|
||||
// "Neighbour"
|
||||
"activate_on_close": "history"
|
||||
"file_icons": false
|
||||
},
|
||||
// Settings related to preview tabs.
|
||||
"preview_tabs": {
|
||||
@@ -691,8 +684,8 @@
|
||||
// "block"
|
||||
// 2. A vertical bar
|
||||
// "bar"
|
||||
// 3. An underline / underscore that runs along the following character
|
||||
// "underline"
|
||||
// 3. An underline that runs along the following character
|
||||
// "underscore"
|
||||
// 4. A box drawn around the following character
|
||||
// "hollow"
|
||||
//
|
||||
@@ -824,7 +817,6 @@
|
||||
// Different settings for specific languages.
|
||||
"languages": {
|
||||
"Astro": {
|
||||
"language_servers": ["astro-language-server", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-astro"]
|
||||
@@ -848,13 +840,6 @@
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Dart": {
|
||||
"tab_size": 2
|
||||
},
|
||||
"Diff": {
|
||||
"remove_trailing_whitespace_on_save": false,
|
||||
"ensure_final_newline_on_save": false
|
||||
},
|
||||
"Elixir": {
|
||||
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
|
||||
},
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
// Server-specific settings
|
||||
//
|
||||
// For a full list of overridable settings, and general information on settings,
|
||||
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
|
||||
{
|
||||
"lsp": {}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ use gpui::{
|
||||
use language::{
|
||||
LanguageRegistry, LanguageServerBinaryStatus, LanguageServerId, LanguageServerName,
|
||||
};
|
||||
use project::{EnvironmentErrorMessage, LanguageServerProgress, Project, WorktreeId};
|
||||
use project::{LanguageServerProgress, Project};
|
||||
use smallvec::SmallVec;
|
||||
use std::{cmp::Reverse, fmt::Write, sync::Arc, time::Duration};
|
||||
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle};
|
||||
@@ -101,7 +101,6 @@ impl ActivityIndicator {
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
buffer.set_capability(language::Capability::ReadOnly, cx);
|
||||
})?;
|
||||
workspace.update(&mut cx, |workspace, cx| {
|
||||
workspace.add_item_to_active_pane(
|
||||
@@ -176,31 +175,7 @@ impl ActivityIndicator {
|
||||
.flatten()
|
||||
}
|
||||
|
||||
fn pending_environment_errors<'a>(
|
||||
&'a self,
|
||||
cx: &'a AppContext,
|
||||
) -> impl Iterator<Item = (&'a WorktreeId, &'a EnvironmentErrorMessage)> {
|
||||
self.project.read(cx).shell_environment_errors(cx)
|
||||
}
|
||||
|
||||
fn content_to_render(&mut self, cx: &mut ViewContext<Self>) -> Option<Content> {
|
||||
// Show if any direnv calls failed
|
||||
if let Some((&worktree_id, error)) = self.pending_environment_errors(cx).next() {
|
||||
return Some(Content {
|
||||
icon: Some(
|
||||
Icon::new(IconName::Warning)
|
||||
.size(IconSize::Small)
|
||||
.into_any_element(),
|
||||
),
|
||||
message: error.0.clone(),
|
||||
on_click: Some(Arc::new(move |this, cx| {
|
||||
this.project.update(cx, |project, cx| {
|
||||
project.remove_environment_error(cx, worktree_id);
|
||||
});
|
||||
cx.dispatch_action(Box::new(workspace::OpenLog));
|
||||
})),
|
||||
});
|
||||
}
|
||||
// Show any language server has pending activity.
|
||||
let mut pending_work = self.pending_language_server_work(cx);
|
||||
if let Some(PendingWork {
|
||||
|
||||
@@ -26,3 +26,6 @@ serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -521,10 +521,6 @@ pub struct Usage {
|
||||
pub input_tokens: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output_tokens: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::{
|
||||
PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata,
|
||||
SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, WorkflowStepResolution,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use client::{proto, Client, Status};
|
||||
@@ -697,9 +697,7 @@ impl AssistantPanel {
|
||||
log::error!("no context found with ID: {}", context_id.to_proto());
|
||||
return;
|
||||
};
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx)
|
||||
.log_err()
|
||||
.flatten();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx).log_err();
|
||||
|
||||
let assistant_panel = cx.view().downgrade();
|
||||
let editor = cx.new_view(|cx| {
|
||||
@@ -973,8 +971,7 @@ impl AssistantPanel {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let workspace = this.workspace.clone();
|
||||
let project = this.project.clone();
|
||||
let lsp_adapter_delegate =
|
||||
make_lsp_adapter_delegate(&project, cx).log_err().flatten();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err();
|
||||
|
||||
let fs = this.fs.clone();
|
||||
let project = this.project.clone();
|
||||
@@ -1004,9 +1001,7 @@ impl AssistantPanel {
|
||||
None
|
||||
} else {
|
||||
let context = self.context_store.update(cx, |store, cx| store.create(cx));
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx)
|
||||
.log_err()
|
||||
.flatten();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx).log_err();
|
||||
|
||||
let assistant_panel = cx.view().downgrade();
|
||||
let editor = cx.new_view(|cx| {
|
||||
@@ -1212,7 +1207,7 @@ impl AssistantPanel {
|
||||
let project = self.project.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err().flatten();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&project, cx).log_err();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let context = context.await?;
|
||||
@@ -1259,9 +1254,7 @@ impl AssistantPanel {
|
||||
.update(cx, |store, cx| store.open_remote_context(id, cx));
|
||||
let fs = self.fs.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx)
|
||||
.log_err()
|
||||
.flatten();
|
||||
let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx).log_err();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let context = context.await?;
|
||||
@@ -1503,13 +1496,6 @@ struct WorkflowAssist {
|
||||
|
||||
type MessageHeader = MessageMetadata;
|
||||
|
||||
#[derive(Clone)]
|
||||
enum AssistError {
|
||||
PaymentRequired,
|
||||
MaxMonthlySpendReached,
|
||||
Message(SharedString),
|
||||
}
|
||||
|
||||
pub struct ContextEditor {
|
||||
context: Model<Context>,
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -1528,7 +1514,7 @@ pub struct ContextEditor {
|
||||
workflow_steps: HashMap<Range<language::Anchor>, WorkflowStepViewState>,
|
||||
active_workflow_step: Option<ActiveWorkflowStep>,
|
||||
assistant_panel: WeakView<AssistantPanel>,
|
||||
last_error: Option<AssistError>,
|
||||
error_message: Option<SharedString>,
|
||||
show_accept_terms: bool,
|
||||
pub(crate) slash_menu_handle:
|
||||
PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
|
||||
@@ -1567,7 +1553,7 @@ impl ContextEditor {
|
||||
editor.set_show_runnables(false, cx);
|
||||
editor.set_show_wrap_guides(false, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_completion_provider(Some(Box::new(completion_provider)));
|
||||
editor.set_completion_provider(Box::new(completion_provider));
|
||||
editor.set_collaboration_hub(Box::new(project.clone()));
|
||||
editor
|
||||
});
|
||||
@@ -1599,7 +1585,7 @@ impl ContextEditor {
|
||||
workflow_steps: HashMap::default(),
|
||||
active_workflow_step: None,
|
||||
assistant_panel,
|
||||
last_error: None,
|
||||
error_message: None,
|
||||
show_accept_terms: false,
|
||||
slash_menu_handle: Default::default(),
|
||||
dragged_file_worktrees: Vec::new(),
|
||||
@@ -1643,7 +1629,7 @@ impl ContextEditor {
|
||||
}
|
||||
|
||||
if !self.apply_active_workflow_step(cx) {
|
||||
self.last_error = None;
|
||||
self.error_message = None;
|
||||
self.send_to_model(cx);
|
||||
cx.notify();
|
||||
}
|
||||
@@ -1793,7 +1779,7 @@ impl ContextEditor {
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
|
||||
self.last_error = None;
|
||||
self.error_message = None;
|
||||
|
||||
if self
|
||||
.context
|
||||
@@ -2298,13 +2284,7 @@ impl ContextEditor {
|
||||
}
|
||||
ContextEvent::Operation(_) => {}
|
||||
ContextEvent::ShowAssistError(error_message) => {
|
||||
self.last_error = Some(AssistError::Message(error_message.clone()));
|
||||
}
|
||||
ContextEvent::ShowPaymentRequiredError => {
|
||||
self.last_error = Some(AssistError::PaymentRequired);
|
||||
}
|
||||
ContextEvent::ShowMaxMonthlySpendReachedError => {
|
||||
self.last_error = Some(AssistError::MaxMonthlySpendReached);
|
||||
self.error_message = Some(error_message.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4318,154 +4298,6 @@ impl ContextEditor {
|
||||
focus_handle.dispatch_action(&Assist, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
|
||||
let last_error = self.last_error.as_ref()?;
|
||||
|
||||
Some(
|
||||
div()
|
||||
.absolute()
|
||||
.right_3()
|
||||
.bottom_12()
|
||||
.max_w_96()
|
||||
.py_2()
|
||||
.px_3()
|
||||
.elevation_2(cx)
|
||||
.occlude()
|
||||
.child(match last_error {
|
||||
AssistError::PaymentRequired => self.render_payment_required_error(cx),
|
||||
AssistError::MaxMonthlySpendReached => {
|
||||
self.render_max_monthly_spend_reached_error(cx)
|
||||
}
|
||||
AssistError::Message(error_message) => {
|
||||
self.render_assist_error(error_message, cx)
|
||||
}
|
||||
})
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_payment_required_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
|
||||
const SUBSCRIBE_URL: &str = "https://zed.dev/api/billing/initiate_checkout";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.open_url(SUBSCRIBE_URL);
|
||||
cx.notify();
|
||||
},
|
||||
)))
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_max_monthly_spend_reached_error(&self, cx: &mut ViewContext<Self>) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
|
||||
const ACCOUNT_URL: &str = "https://zed.dev/account";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(
|
||||
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||
cx.listener(|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.open_url(ACCOUNT_URL);
|
||||
cx.notify();
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_assist_error(
|
||||
&self,
|
||||
error_message: &SharedString,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> AnyElement {
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(
|
||||
Label::new("Error interacting with language model")
|
||||
.weight(FontWeight::MEDIUM),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(error_message.clone())),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the contents of the *outermost* fenced code block that contains the given offset.
|
||||
@@ -4602,7 +4434,48 @@ impl Render for ContextEditor {
|
||||
.child(element),
|
||||
)
|
||||
})
|
||||
.children(self.render_last_error(cx))
|
||||
.when_some(self.error_message.clone(), |this, error_message| {
|
||||
this.child(
|
||||
div()
|
||||
.absolute()
|
||||
.right_3()
|
||||
.bottom_12()
|
||||
.max_w_96()
|
||||
.py_2()
|
||||
.px_3()
|
||||
.elevation_2(cx)
|
||||
.occlude()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(
|
||||
Label::new("Error interacting with language model")
|
||||
.weight(FontWeight::MEDIUM),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(error_message)),
|
||||
)
|
||||
.child(h_flex().justify_end().mt_1().child(
|
||||
Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, cx| {
|
||||
this.error_message = None;
|
||||
cx.notify();
|
||||
},
|
||||
)),
|
||||
)),
|
||||
),
|
||||
)
|
||||
})
|
||||
.child(
|
||||
h_flex().w_full().relative().child(
|
||||
h_flex()
|
||||
@@ -5632,21 +5505,22 @@ fn render_docs_slash_command_trailer(
|
||||
fn make_lsp_adapter_delegate(
|
||||
project: &Model<Project>,
|
||||
cx: &mut AppContext,
|
||||
) -> Result<Option<Arc<dyn LspAdapterDelegate>>> {
|
||||
) -> Result<Arc<dyn LspAdapterDelegate>> {
|
||||
project.update(cx, |project, cx| {
|
||||
// TODO: Find the right worktree.
|
||||
let Some(worktree) = project.worktrees(cx).next() else {
|
||||
return Ok(None::<Arc<dyn LspAdapterDelegate>>);
|
||||
};
|
||||
let worktree = project
|
||||
.worktrees(cx)
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("no worktrees when constructing LocalLspAdapterDelegate"))?;
|
||||
let http_client = project.client().http_client().clone();
|
||||
project.lsp_store().update(cx, |lsp_store, cx| {
|
||||
Ok(Some(LocalLspAdapterDelegate::new(
|
||||
Ok(LocalLspAdapterDelegate::new(
|
||||
lsp_store,
|
||||
&worktree,
|
||||
http_client,
|
||||
project.fs().clone(),
|
||||
cx,
|
||||
) as Arc<dyn LspAdapterDelegate>))
|
||||
) as Arc<dyn LspAdapterDelegate>)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ use gpui::{
|
||||
|
||||
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
|
||||
use language_model::{
|
||||
provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError},
|
||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
||||
@@ -295,8 +294,6 @@ impl ContextOperation {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ContextEvent {
|
||||
ShowAssistError(SharedString),
|
||||
ShowPaymentRequiredError,
|
||||
ShowMaxMonthlySpendReachedError,
|
||||
MessagesEdited,
|
||||
SummaryChanged,
|
||||
StreamedCompletion,
|
||||
@@ -2115,36 +2112,25 @@ impl Context {
|
||||
let result = stream_completion.await;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let error_message = if let Some(error) = result.as_ref().err() {
|
||||
if error.is::<PaymentRequiredError>() {
|
||||
cx.emit(ContextEvent::ShowPaymentRequiredError);
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Canceled;
|
||||
});
|
||||
Some(error.to_string())
|
||||
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Canceled;
|
||||
});
|
||||
Some(error.to_string())
|
||||
let error_message = result
|
||||
.as_ref()
|
||||
.err()
|
||||
.map(|error| error.to_string().trim().to_string());
|
||||
|
||||
if let Some(error_message) = error_message.as_ref() {
|
||||
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
|
||||
error_message.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
if let Some(error_message) = error_message.as_ref() {
|
||||
metadata.status =
|
||||
MessageStatus::Error(SharedString::from(error_message.clone()));
|
||||
} else {
|
||||
let error_message = error.to_string().trim().to_string();
|
||||
cx.emit(ContextEvent::ShowAssistError(SharedString::from(
|
||||
error_message.clone(),
|
||||
)));
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status =
|
||||
MessageStatus::Error(SharedString::from(error_message.clone()));
|
||||
});
|
||||
Some(error_message)
|
||||
}
|
||||
} else {
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Done;
|
||||
});
|
||||
None
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(telemetry) = this.telemetry.as_ref() {
|
||||
let language_name = this
|
||||
@@ -2160,7 +2146,7 @@ impl Context {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
language_name,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -267,7 +267,7 @@ impl InlineAssistant {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: buffer.language().map(|language| language.name().to_proto()),
|
||||
language_name: buffer.language().map(|language| language.name()),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -788,7 +788,7 @@ impl InlineAssistant {
|
||||
model_provider: model.provider_id().to_string(),
|
||||
response_latency: None,
|
||||
error_message: None,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
language_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -2278,7 +2278,7 @@ impl InlineAssist {
|
||||
struct InlineAssistantError;
|
||||
|
||||
let id =
|
||||
NotificationId::composite::<InlineAssistantError>(
|
||||
NotificationId::identified::<InlineAssistantError>(
|
||||
assist_id.0,
|
||||
);
|
||||
|
||||
@@ -2954,7 +2954,7 @@ impl CodegenAlternative {
|
||||
model_provider: model_provider_id.to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
language_name: language_name.map(|name| name.to_proto()),
|
||||
language_name,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -521,9 +521,9 @@ impl PromptLibrary {
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_use_modal_editing(false);
|
||||
editor.set_current_line_highlight(Some(CurrentLineHighlight::None));
|
||||
editor.set_completion_provider(Some(Box::new(
|
||||
editor.set_completion_provider(Box::new(
|
||||
SlashCommandCompletionProvider::new(None, None),
|
||||
)));
|
||||
));
|
||||
if focus {
|
||||
editor.focus(cx);
|
||||
}
|
||||
|
||||
@@ -38,10 +38,7 @@ impl Settings for SlashCommandSettings {
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _cx: &mut AppContext) -> Result<Self> {
|
||||
SettingsSources::<Self::FileContent>::json_merge_with(
|
||||
[sources.default]
|
||||
.into_iter()
|
||||
.chain(sources.user)
|
||||
.chain(sources.server),
|
||||
[sources.default].into_iter().chain(sources.user),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ impl TerminalInlineAssist {
|
||||
struct InlineAssistantError;
|
||||
|
||||
let id =
|
||||
NotificationId::composite::<InlineAssistantError>(
|
||||
NotificationId::identified::<InlineAssistantError>(
|
||||
assist_id.0,
|
||||
);
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ impl Settings for AutoUpdateSetting {
|
||||
type FileContent = Option<AutoUpdateSettingContent>;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
let auto_update = [sources.server, sources.release_channel, sources.user]
|
||||
let auto_update = [sources.release_channel, sources.user]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
@@ -464,7 +464,6 @@ impl AutoUpdater {
|
||||
smol::fs::create_dir_all(&platform_dir).await.ok();
|
||||
|
||||
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
|
||||
|
||||
if smol::fs::metadata(&version_path).await.is_err() {
|
||||
log::info!("downloading zed-remote-server {os} {arch}");
|
||||
download_remote_server_binary(&version_path, release, client, cx).await?;
|
||||
|
||||
@@ -1178,7 +1178,7 @@ impl Room {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.joined_projects.retain(|project| {
|
||||
if let Some(project) = project.upgrade() {
|
||||
!project.read(cx).is_disconnected(cx)
|
||||
!project.read(cx).is_disconnected()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -58,32 +58,27 @@ struct Args {
|
||||
dev_server_token: Option<String>,
|
||||
}
|
||||
|
||||
fn parse_path_with_position(argument_str: &str) -> anyhow::Result<String> {
|
||||
let canonicalized = match Path::new(argument_str).canonicalize() {
|
||||
Ok(existing_path) => PathWithPosition::from_path(existing_path),
|
||||
Err(_) => {
|
||||
let path = PathWithPosition::parse_str(argument_str);
|
||||
let curdir = env::current_dir().context("reteiving current directory")?;
|
||||
path.map_path(|path| match fs::canonicalize(&path) {
|
||||
Ok(path) => Ok(path),
|
||||
Err(e) => {
|
||||
if let Some(mut parent) = path.parent() {
|
||||
if parent == Path::new("") {
|
||||
parent = &curdir
|
||||
}
|
||||
match fs::canonicalize(parent) {
|
||||
Ok(parent) => Ok(parent.join(path.file_name().unwrap())),
|
||||
Err(_) => Err(e),
|
||||
}
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
fn parse_path_with_position(argument_str: &str) -> Result<String, std::io::Error> {
|
||||
let path = PathWithPosition::parse_str(argument_str);
|
||||
let curdir = env::current_dir()?;
|
||||
|
||||
let canonicalized = path.map_path(|path| match fs::canonicalize(&path) {
|
||||
Ok(path) => Ok(path),
|
||||
Err(e) => {
|
||||
if let Some(mut parent) = path.parent() {
|
||||
if parent == Path::new("") {
|
||||
parent = &curdir
|
||||
}
|
||||
})
|
||||
match fs::canonicalize(parent) {
|
||||
Ok(parent) => Ok(parent.join(path.file_name().unwrap())),
|
||||
Err(_) => Err(e),
|
||||
}
|
||||
} else {
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
.with_context(|| format!("parsing as path with position {argument_str}"))?,
|
||||
};
|
||||
Ok(canonicalized.to_string(|path| path.to_string_lossy().to_string()))
|
||||
})?;
|
||||
Ok(canonicalized.to_string(|path| path.display().to_string()))
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
|
||||
@@ -18,6 +18,7 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tls = "0.13"
|
||||
async-tungstenite = { workspace = true, features = ["async-std", "async-tls"] }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
@@ -34,8 +35,6 @@ postage.workspace = true
|
||||
rand.workspace = true
|
||||
release_channel.workspace = true
|
||||
rpc = { workspace = true, features = ["gpui"] }
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -141,7 +141,6 @@ impl Settings for ProxySettings {
|
||||
Ok(Self {
|
||||
proxy: sources
|
||||
.user
|
||||
.or(sources.server)
|
||||
.and_then(|value| value.proxy.clone())
|
||||
.or(sources.default.proxy.clone()),
|
||||
})
|
||||
@@ -473,21 +472,15 @@ impl settings::Settings for TelemetrySettings {
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
Ok(Self {
|
||||
diagnostics: sources
|
||||
.user
|
||||
.as_ref()
|
||||
.or(sources.server.as_ref())
|
||||
.and_then(|v| v.diagnostics)
|
||||
.unwrap_or(
|
||||
sources
|
||||
.default
|
||||
.diagnostics
|
||||
.ok_or_else(Self::missing_default)?,
|
||||
),
|
||||
diagnostics: sources.user.as_ref().and_then(|v| v.diagnostics).unwrap_or(
|
||||
sources
|
||||
.default
|
||||
.diagnostics
|
||||
.ok_or_else(Self::missing_default)?,
|
||||
),
|
||||
metrics: sources
|
||||
.user
|
||||
.as_ref()
|
||||
.or(sources.server.as_ref())
|
||||
.and_then(|v| v.metrics)
|
||||
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
|
||||
})
|
||||
@@ -1144,31 +1137,13 @@ impl Client {
|
||||
|
||||
match url_scheme {
|
||||
Https => {
|
||||
let client_config = {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
let root_certs = rustls_native_certs::load_native_certs();
|
||||
for error in root_certs.errors {
|
||||
log::warn!("error loading native certs: {:?}", error);
|
||||
}
|
||||
root_store.add_parsable_certificates(
|
||||
&root_certs
|
||||
.certs
|
||||
.into_iter()
|
||||
.map(|cert| cert.as_ref().to_owned())
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
let (stream, _) =
|
||||
async_tungstenite::async_tls::client_async_tls_with_connector(
|
||||
request,
|
||||
stream,
|
||||
Some(client_config.into()),
|
||||
Some(async_tls::TlsConnector::from(
|
||||
http_client::TLS_CONFIG.clone(),
|
||||
)),
|
||||
)
|
||||
.await?;
|
||||
Ok(Connection::new(
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
mod event_coalescer;
|
||||
|
||||
use crate::{ChannelId, TelemetrySettings};
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use clock::SystemClock;
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::Future;
|
||||
use gpui::{AppContext, BackgroundExecutor, Task};
|
||||
use http_client::{self, AsyncBody, HttpClient, HttpClientWithUrl, Method, Request};
|
||||
use http_client::{self, HttpClient, HttpClientWithUrl, Method};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
use release_channel::ReleaseChannel;
|
||||
@@ -365,7 +364,6 @@ impl Telemetry {
|
||||
operation: &'static str,
|
||||
copilot_enabled: bool,
|
||||
copilot_enabled_for_language: bool,
|
||||
is_via_ssh: bool,
|
||||
) {
|
||||
let event = Event::Editor(EditorEvent {
|
||||
file_extension,
|
||||
@@ -373,7 +371,6 @@ impl Telemetry {
|
||||
operation: operation.into(),
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
is_via_ssh,
|
||||
});
|
||||
|
||||
self.report_event(event)
|
||||
@@ -459,7 +456,7 @@ impl Telemetry {
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str, is_via_ssh: bool) {
|
||||
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str) {
|
||||
let mut state = self.state.lock();
|
||||
let period_data = state.event_coalescer.log_event(environment);
|
||||
drop(state);
|
||||
@@ -468,7 +465,6 @@ impl Telemetry {
|
||||
let event = Event::Edit(EditEvent {
|
||||
duration: end.timestamp_millis() - start.timestamp_millis(),
|
||||
environment: environment.to_string(),
|
||||
is_via_ssh,
|
||||
});
|
||||
|
||||
self.report_event(event);
|
||||
@@ -489,7 +485,7 @@ impl Telemetry {
|
||||
worktree_id: WorktreeId,
|
||||
updated_entries_set: &UpdatedEntriesSet,
|
||||
) {
|
||||
let project_type_names: Vec<String> = {
|
||||
let project_names: Vec<String> = {
|
||||
let mut state = self.state.lock();
|
||||
state
|
||||
.worktree_id_map
|
||||
@@ -525,8 +521,8 @@ impl Telemetry {
|
||||
};
|
||||
|
||||
// Done on purpose to avoid calling `self.state.lock()` multiple times
|
||||
for project_type_name in project_type_names {
|
||||
self.report_app_event(format!("open {} project", project_type_name));
|
||||
for project_name in project_names {
|
||||
self.report_app_event(format!("open {} project", project_name));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -598,29 +594,6 @@ impl Telemetry {
|
||||
self.state.lock().is_staff
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
self: &Arc<Self>,
|
||||
// We take in the JSON bytes buffer so we can reuse the existing allocation.
|
||||
mut json_bytes: Vec<u8>,
|
||||
event_request: EventRequestBody,
|
||||
) -> Result<Request<AsyncBody>> {
|
||||
json_bytes.clear();
|
||||
serde_json::to_writer(&mut json_bytes, &event_request)?;
|
||||
|
||||
let checksum = calculate_json_checksum(&json_bytes).unwrap_or("".to_string());
|
||||
|
||||
Ok(Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(
|
||||
self.http_client
|
||||
.build_zed_api_url("/telemetry/events", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("x-zed-checksum", checksum)
|
||||
.body(json_bytes.into())?)
|
||||
}
|
||||
|
||||
pub fn flush_events(self: &Arc<Self>) {
|
||||
let mut state = self.state.lock();
|
||||
state.first_event_date_time = None;
|
||||
@@ -647,10 +620,10 @@ impl Telemetry {
|
||||
}
|
||||
}
|
||||
|
||||
let request_body = {
|
||||
{
|
||||
let state = this.state.lock();
|
||||
|
||||
EventRequestBody {
|
||||
let request_body = EventRequestBody {
|
||||
system_id: state.system_id.as_deref().map(Into::into),
|
||||
installation_id: state.installation_id.as_deref().map(Into::into),
|
||||
session_id: state.session_id.clone(),
|
||||
@@ -663,11 +636,25 @@ impl Telemetry {
|
||||
|
||||
release_channel: state.release_channel.map(Into::into),
|
||||
events,
|
||||
}
|
||||
};
|
||||
};
|
||||
json_bytes.clear();
|
||||
serde_json::to_writer(&mut json_bytes, &request_body)?;
|
||||
}
|
||||
|
||||
let request = this.build_request(json_bytes, request_body)?;
|
||||
let response = this.http_client.send(request).await?;
|
||||
let checksum = calculate_json_checksum(&json_bytes).unwrap_or("".to_string());
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(
|
||||
this.http_client
|
||||
.build_zed_api_url("/telemetry/events", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "text/plain")
|
||||
.header("x-zed-checksum", checksum)
|
||||
.body(json_bytes.into());
|
||||
|
||||
let response = this.http_client.send(request?).await?;
|
||||
if response.status() != 200 {
|
||||
log::error!("Failed to send events: HTTP {:?}", response.status());
|
||||
}
|
||||
|
||||
@@ -216,11 +216,10 @@ impl fmt::Debug for Global {
|
||||
if timestamp.replica_id > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
if timestamp.replica_id == LOCAL_BRANCH_REPLICA_ID {
|
||||
write!(f, "<branch>: {}", timestamp.value)?;
|
||||
} else {
|
||||
write!(f, "{}: {}", timestamp.replica_id, timestamp.value)?;
|
||||
}
|
||||
write!(f, "{}: {}", timestamp.replica_id, timestamp.value)?;
|
||||
}
|
||||
if self.local_branch_value > 0 {
|
||||
write!(f, "<branch>: {}", self.local_branch_value)?;
|
||||
}
|
||||
write!(f, "}}")
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ clickhouse.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
@@ -67,7 +66,7 @@ telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tokio.workspace = true
|
||||
toml.workspace = true
|
||||
tower = "0.4"
|
||||
tower-http = { workspace = true, features = ["trace"] }
|
||||
|
||||
@@ -422,15 +422,6 @@ CREATE TABLE dev_server_projects (
|
||||
paths TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_customers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
create table if not exists billing_preferences (
|
||||
id serial primary key,
|
||||
created_at timestamp without time zone not null default now(),
|
||||
user_id integer not null references users(id) on delete cascade,
|
||||
max_monthly_llm_usage_spending_in_cents integer not null
|
||||
);
|
||||
|
||||
create unique index "uix_billing_preferences_on_user_id" on billing_preferences (user_id);
|
||||
@@ -1,11 +0,0 @@
|
||||
alter table models
|
||||
add column price_per_million_cache_creation_input_tokens integer not null default 0,
|
||||
add column price_per_million_cache_read_input_tokens integer not null default 0;
|
||||
|
||||
alter table usages
|
||||
add column cache_creation_input_tokens_this_month bigint not null default 0,
|
||||
add column cache_read_input_tokens_this_month bigint not null default 0;
|
||||
|
||||
alter table lifetime_usages
|
||||
add column cache_creation_input_tokens bigint not null default 0,
|
||||
add column cache_read_input_tokens bigint not null default 0;
|
||||
@@ -1,3 +0,0 @@
|
||||
alter table usages
|
||||
drop column cache_creation_input_tokens_this_month,
|
||||
drop column cache_read_input_tokens_this_month;
|
||||
@@ -1,13 +0,0 @@
|
||||
create table monthly_usages (
|
||||
id serial primary key,
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
month integer not null,
|
||||
year integer not null,
|
||||
input_tokens bigint not null default 0,
|
||||
cache_creation_input_tokens bigint not null default 0,
|
||||
cache_read_input_tokens bigint not null default 0,
|
||||
output_tokens bigint not null default 0
|
||||
);
|
||||
|
||||
create unique index uix_monthly_usages_on_user_id_model_id_month_year on monthly_usages (user_id, model_id, month, year);
|
||||
@@ -1,12 +0,0 @@
|
||||
create table billing_events (
|
||||
id serial primary key,
|
||||
idempotency_key uuid not null default gen_random_uuid(),
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
input_tokens bigint not null default 0,
|
||||
input_cache_creation_tokens bigint not null default 0,
|
||||
input_cache_read_tokens bigint not null default 0,
|
||||
output_tokens bigint not null default 0
|
||||
);
|
||||
|
||||
create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);
|
||||
@@ -1,3 +1,7 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use axum::{
|
||||
extract::{self, Query},
|
||||
@@ -5,43 +9,29 @@ use axum::{
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use collections::HashSet;
|
||||
use reqwest::StatusCode;
|
||||
use sea_orm::ActiveValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use stripe::{
|
||||
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
|
||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
|
||||
Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::rpc::ResultExt as _;
|
||||
use crate::{
|
||||
db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
},
|
||||
stripe_billing::StripeBilling,
|
||||
};
|
||||
use crate::{
|
||||
db::{billing_subscription::StripeSubscriptionStatus, UserId},
|
||||
llm::db::LlmDatabase,
|
||||
use crate::db::billing_subscription::StripeSubscriptionStatus;
|
||||
use crate::db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
};
|
||||
use crate::{AppState, Error, Result};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route(
|
||||
"/billing/preferences",
|
||||
get(get_billing_preferences).put(update_billing_preferences),
|
||||
)
|
||||
.route(
|
||||
"/billing/subscriptions",
|
||||
get(list_billing_subscriptions).post(create_billing_subscription),
|
||||
@@ -52,85 +42,6 @@ pub fn router() -> Router {
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetBillingPreferencesParams {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn get_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetBillingPreferencesParams>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(params.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let preferences = app.db.get_billing_preferences(user.id).await?;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents
|
||||
}),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpdateBillingPreferencesBody {
|
||||
github_user_id: i32,
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn update_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
|
||||
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(body.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let billing_preferences =
|
||||
if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
|
||||
app.db
|
||||
.update_billing_preferences(
|
||||
user.id,
|
||||
&UpdateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
body.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
},
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
app.db
|
||||
.create_billing_preferences(
|
||||
user.id,
|
||||
&crate::db::CreateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents: body
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
rpc_server.refresh_llm_tokens_for_user(user.id).await;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ListBillingSubscriptionsParams {
|
||||
github_user_id: i32,
|
||||
@@ -168,7 +79,7 @@ async fn list_billing_subscriptions(
|
||||
.into_iter()
|
||||
.map(|subscription| BillingSubscriptionJson {
|
||||
id: subscription.id,
|
||||
name: "Zed LLM Usage".to_string(),
|
||||
name: "Zed Pro".to_string(),
|
||||
status: subscription.stripe_subscription_status,
|
||||
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
|
||||
cancel_at
|
||||
@@ -203,22 +114,12 @@ async fn create_billing_subscription(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::error!("failed to retrieve LLM database");
|
||||
let Some((stripe_client, stripe_price_id)) = app
|
||||
.stripe_client
|
||||
.clone()
|
||||
.zip(app.config.stripe_price_id.clone())
|
||||
else {
|
||||
log::error!("failed to retrieve Stripe client or price ID");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
@@ -242,14 +143,26 @@ async fn create_billing_subscription(
|
||||
customer.id
|
||||
};
|
||||
|
||||
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
let checkout_session_url = stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?;
|
||||
let checkout_session = {
|
||||
let mut params = CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(user.github_login.as_str());
|
||||
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
|
||||
price: Some(stripe_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
params.success_url = Some(&success_url);
|
||||
|
||||
CheckoutSession::create(&stripe_client, params).await?
|
||||
};
|
||||
|
||||
Ok(Json(CreateBillingSubscriptionResponse {
|
||||
checkout_session_url,
|
||||
checkout_session_url: checkout_session
|
||||
.url
|
||||
.ok_or_else(|| anyhow!("no checkout session URL"))?,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -718,73 +631,3 @@ async fn find_or_create_billing_customer(
|
||||
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.trace_err();
|
||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn sync_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let events = llm_db.get_billing_events().await?;
|
||||
let user_ids = events
|
||||
.iter()
|
||||
.map(|(event, _)| event.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
|
||||
|
||||
for (event, model) in events {
|
||||
let Some((stripe_db_customer, stripe_db_subscription)) =
|
||||
stripe_subscriptions.get(&event.user_id)
|
||||
else {
|
||||
tracing::warn!(
|
||||
user_id = event.user_id.0,
|
||||
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
|
||||
.stripe_subscription_id
|
||||
.parse()
|
||||
.context("failed to parse stripe subscription id from db")?;
|
||||
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
|
||||
.stripe_customer_id
|
||||
.parse()
|
||||
.context("failed to parse stripe customer id from db")?;
|
||||
|
||||
let stripe_model = stripe_billing.register_model(&model).await?;
|
||||
stripe_billing
|
||||
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.await?;
|
||||
llm_db.consume_billing_event(event.id).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -429,6 +429,8 @@ pub async fn post_events(
|
||||
country_code.clone(),
|
||||
checksum_matched,
|
||||
)),
|
||||
// Needed for clients sending old copilot_event types
|
||||
Event::Copilot(_) => {}
|
||||
Event::InlineCompletion(event) => {
|
||||
to_upload
|
||||
.inline_completion_events
|
||||
@@ -670,13 +672,13 @@ pub struct EditorEventRow {
|
||||
time: i64,
|
||||
copilot_enabled: bool,
|
||||
copilot_enabled_for_language: bool,
|
||||
historical_event: bool,
|
||||
architecture: String,
|
||||
is_staff: Option<bool>,
|
||||
major: Option<i32>,
|
||||
minor: Option<i32>,
|
||||
patch: Option<i32>,
|
||||
checksum_matched: bool,
|
||||
is_via_ssh: bool,
|
||||
}
|
||||
|
||||
impl EditorEventRow {
|
||||
@@ -717,7 +719,7 @@ impl EditorEventRow {
|
||||
country_code: country_code.unwrap_or("XX".to_string()),
|
||||
region_code: "".to_string(),
|
||||
city: "".to_string(),
|
||||
is_via_ssh: event.is_via_ssh,
|
||||
historical_event: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1261,7 +1263,6 @@ pub struct EditEventRow {
|
||||
period_start: i64,
|
||||
period_end: i64,
|
||||
environment: String,
|
||||
is_via_ssh: bool,
|
||||
}
|
||||
|
||||
impl EditEventRow {
|
||||
@@ -1295,7 +1296,6 @@ impl EditEventRow {
|
||||
period_start: period_start.timestamp_millis(),
|
||||
period_end: period_end.timestamp_millis(),
|
||||
environment: event.environment,
|
||||
is_via_ssh: event.is_via_ssh,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
/// A number of cents.
|
||||
#[derive(
|
||||
Debug,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Hash,
|
||||
Clone,
|
||||
Copy,
|
||||
derive_more::Add,
|
||||
derive_more::AddAssign,
|
||||
)]
|
||||
pub struct Cents(pub u32);
|
||||
|
||||
impl Cents {
|
||||
pub const ZERO: Self = Self(0);
|
||||
|
||||
pub const fn new(cents: u32) -> Self {
|
||||
Self(cents)
|
||||
}
|
||||
|
||||
pub const fn from_dollars(dollars: u32) -> Self {
|
||||
Self(dollars * 100)
|
||||
}
|
||||
|
||||
pub fn saturating_sub(self, other: Cents) -> Self {
|
||||
Self(self.0.saturating_sub(other.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cents_new() {
|
||||
assert_eq!(Cents::new(50), Cents(50));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_from_dollars() {
|
||||
assert_eq!(Cents::from_dollars(1), Cents(100));
|
||||
assert_eq!(Cents::from_dollars(5), Cents(500));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_zero() {
|
||||
assert_eq!(Cents::ZERO, Cents(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_add() {
|
||||
assert_eq!(Cents(50) + Cents(30), Cents(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_add_assign() {
|
||||
let mut cents = Cents(50);
|
||||
cents += Cents(30);
|
||||
assert_eq!(cents, Cents(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_saturating_sub() {
|
||||
assert_eq!(Cents(50).saturating_sub(Cents(30)), Cents(20));
|
||||
assert_eq!(Cents(30).saturating_sub(Cents(50)), Cents(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cents_ordering() {
|
||||
assert!(Cents(50) > Cents(30));
|
||||
assert!(Cents(30) < Cents(50));
|
||||
assert_eq!(Cents(50), Cents(50));
|
||||
}
|
||||
}
|
||||
@@ -42,9 +42,6 @@ pub use tests::TestDb;
|
||||
|
||||
pub use ids::*;
|
||||
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
|
||||
pub use queries::billing_preferences::{
|
||||
CreateBillingPreferencesParams, UpdateBillingPreferencesParams,
|
||||
};
|
||||
pub use queries::billing_subscriptions::{
|
||||
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
|
||||
};
|
||||
|
||||
@@ -72,7 +72,6 @@ macro_rules! id_type {
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(BillingCustomerId);
|
||||
id_type!(BillingSubscriptionId);
|
||||
id_type!(BillingPreferencesId);
|
||||
id_type!(BufferId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(ChannelChatParticipantId);
|
||||
|
||||
@@ -2,7 +2,6 @@ use super::*;
|
||||
|
||||
pub mod access_tokens;
|
||||
pub mod billing_customers;
|
||||
pub mod billing_preferences;
|
||||
pub mod billing_subscriptions;
|
||||
pub mod buffers;
|
||||
pub mod channels;
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: ActiveValue<i32>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Returns the billing preferences for the given user, if they exist.
|
||||
pub async fn get_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_preference::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_preference::Entity::find()
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Creates new billing preferences for the given user.
|
||||
pub async fn create_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
params: &CreateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
params.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(preferences)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the billing preferences for the given user.
|
||||
pub async fn update_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
params: &UpdateBillingPreferencesParams,
|
||||
) -> Result<billing_preference::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let preferences = billing_preference::Entity::update_many()
|
||||
.set(billing_preference::ActiveModel {
|
||||
max_monthly_llm_usage_spending_in_cents: params
|
||||
.max_monthly_llm_usage_spending_in_cents
|
||||
.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(preferences
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("billing preferences not found"))?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -112,37 +112,6 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_billing_subscriptions(
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the user has an active billing subscription.
|
||||
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
|
||||
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
|
||||
|
||||
@@ -838,7 +838,6 @@ impl Database {
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id as u64,
|
||||
name: language_server.name,
|
||||
worktree_id: None,
|
||||
})
|
||||
.collect(),
|
||||
dev_server_project_id: project.dev_server_project_id,
|
||||
|
||||
@@ -718,7 +718,6 @@ impl Database {
|
||||
.map(|language_server| proto::LanguageServer {
|
||||
id: language_server.id as u64,
|
||||
name: language_server.name,
|
||||
worktree_id: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod access_token;
|
||||
pub mod billing_customer;
|
||||
pub mod billing_preference;
|
||||
pub mod billing_subscription;
|
||||
pub mod buffer;
|
||||
pub mod buffer_operation;
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
use crate::db::{BillingPreferencesId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_preferences")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingPreferencesId,
|
||||
pub created_at: DateTime,
|
||||
pub user_id: UserId,
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,6 +1,5 @@
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
mod cents;
|
||||
pub mod clickhouse;
|
||||
pub mod db;
|
||||
pub mod env;
|
||||
@@ -10,7 +9,6 @@ pub mod migrations;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
pub mod user_backfiller;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -22,17 +20,13 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
pub use cents::*;
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
pub use rate_limiter::*;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
@@ -180,6 +174,7 @@ pub struct Config {
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub stripe_price_id: Option<Arc<str>>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
@@ -198,10 +193,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_llm_billing_enabled(&self) -> bool {
|
||||
self.stripe_api_key.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
Self {
|
||||
@@ -240,6 +231,7 @@ impl Config {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
}
|
||||
@@ -272,11 +264,9 @@ impl ServiceMode {
|
||||
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub clickhouse_client: Option<::clickhouse::Client>,
|
||||
@@ -290,20 +280,6 @@ impl AppState {
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||
.llm_database_url
|
||||
.clone()
|
||||
.zip(config.llm_database_max_connections)
|
||||
{
|
||||
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
|
||||
llm_db_options.max_connections(llm_database_max_connections);
|
||||
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
|
||||
llm_db.initialize().await?;
|
||||
Some(Arc::new(llm_db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let live_kit_client = if let Some(((server, key), secret)) = config
|
||||
.live_kit_server
|
||||
.as_ref()
|
||||
@@ -320,16 +296,11 @@ impl AppState {
|
||||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
live_kit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
clickhouse_client: config
|
||||
@@ -342,11 +313,12 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
let api_key = config
|
||||
.stripe_api_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
|
||||
|
||||
Ok(stripe::Client::new(api_key))
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ mod telemetry;
|
||||
mod token;
|
||||
|
||||
use crate::{
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
|
||||
Config, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
@@ -20,14 +20,14 @@ use axum::{
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use collections::HashMap;
|
||||
use db::TokenUsage;
|
||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::ListModelsResponse;
|
||||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
@@ -318,31 +318,22 @@ async fn perform_completion(
|
||||
chunks
|
||||
.map(move |event| {
|
||||
let chunk = event?;
|
||||
let (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
) = match &chunk {
|
||||
let (input_tokens, output_tokens) = match &chunk {
|
||||
anthropic::Event::MessageStart {
|
||||
message: anthropic::Response { usage, .. },
|
||||
}
|
||||
| anthropic::Event::MessageDelta { usage, .. } => (
|
||||
usage.input_tokens.unwrap_or(0) as usize,
|
||||
usage.output_tokens.unwrap_or(0) as usize,
|
||||
usage.cache_creation_input_tokens.unwrap_or(0) as usize,
|
||||
usage.cache_read_input_tokens.unwrap_or(0) as usize,
|
||||
),
|
||||
_ => (0, 0, 0, 0),
|
||||
_ => (0, 0),
|
||||
};
|
||||
|
||||
anyhow::Ok(CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
anyhow::Ok((
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
})
|
||||
))
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
@@ -368,13 +359,11 @@ async fn perform_completion(
|
||||
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||
let output_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||
CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
(
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
}
|
||||
)
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
@@ -398,13 +387,13 @@ async fn perform_completion(
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
// TODO - implement token counting for Google AI
|
||||
CompletionChunk {
|
||||
bytes: serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
}
|
||||
let input_tokens = 0;
|
||||
let output_tokens = 0;
|
||||
(
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
@@ -416,7 +405,8 @@ async fn perform_completion(
|
||||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
tokens: TokenUsage::default(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
@@ -433,18 +423,10 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// The maximum monthly spending an individual user can reach on the free tier
|
||||
/// before they have to pay.
|
||||
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5);
|
||||
|
||||
/// The default value to use for maximum spend per month if the user did not
|
||||
/// explicitly set a maximum spend.
|
||||
///
|
||||
/// Used to prevent surprise bills.
|
||||
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
|
||||
|
||||
/// The maximum lifetime spending an individual user can reach before being cut off.
|
||||
const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);
|
||||
///
|
||||
/// Represented in cents.
|
||||
const LIFETIME_SPENDING_LIMIT_IN_CENTS: usize = 1_000 * 100;
|
||||
|
||||
async fn check_usage_limit(
|
||||
state: &Arc<LlmState>,
|
||||
@@ -463,32 +445,7 @@ async fn check_usage_limit(
|
||||
)
|
||||
.await?;
|
||||
|
||||
if state.config.is_llm_billing_enabled() {
|
||||
if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT {
|
||||
if !claims.has_llm_subscription {
|
||||
return Err(Error::http(
|
||||
StatusCode::PAYMENT_REQUIRED,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
|
||||
return Err(Error::Http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached for this month.".to_string(),
|
||||
[(
|
||||
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||
HeaderValue::from_static("true"),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Remove this once we've rolled out monthly spending limits.
|
||||
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT {
|
||||
if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS {
|
||||
return Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"Maximum spending limit reached.".to_string(),
|
||||
@@ -535,6 +492,7 @@ async fn check_usage_limit(
|
||||
UsageMeasure::RequestsPerMinute => "requests_per_minute",
|
||||
UsageMeasure::TokensPerMinute => "tokens_per_minute",
|
||||
UsageMeasure::TokensPerDay => "tokens_per_day",
|
||||
_ => "",
|
||||
};
|
||||
|
||||
if let Some(client) = state.clickhouse_client.as_ref() {
|
||||
@@ -593,38 +551,29 @@ async fn check_usage_limit(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CompletionChunk {
|
||||
bytes: Vec<u8>,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
cache_creation_input_tokens: usize,
|
||||
cache_read_input_tokens: usize,
|
||||
}
|
||||
|
||||
struct TokenCountingStream<S> {
|
||||
state: Arc<LlmState>,
|
||||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
tokens: TokenUsage,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
impl<S> Stream for TokenCountingStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<CompletionChunk, anyhow::Error>> + Unpin,
|
||||
S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
|
||||
{
|
||||
type Item = Result<Vec<u8>, anyhow::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||
chunk.bytes.push(b'\n');
|
||||
self.tokens.input += chunk.input_tokens;
|
||||
self.tokens.output += chunk.output_tokens;
|
||||
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||
Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
|
||||
bytes.push(b'\n');
|
||||
self.input_tokens += input_tokens;
|
||||
self.output_tokens += output_tokens;
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
@@ -636,11 +585,11 @@ where
|
||||
impl<S> Drop for TokenCountingStream<S> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let is_llm_billing_enabled = state.config.is_llm_billing_enabled();
|
||||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let tokens = self.tokens;
|
||||
let input_token_count = self.input_tokens;
|
||||
let output_token_count = self.output_tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
@@ -649,16 +598,8 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
tokens,
|
||||
// We're passing `false` here if LLM billing is not enabled
|
||||
// so that we don't write any records to the
|
||||
// `billing_events` table until we're ready to bill users.
|
||||
if is_llm_billing_enabled {
|
||||
claims.has_llm_subscription
|
||||
} else {
|
||||
false
|
||||
},
|
||||
Cents(claims.max_monthly_spend_in_cents),
|
||||
input_token_count,
|
||||
output_token_count,
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
@@ -688,25 +629,15 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||
},
|
||||
model,
|
||||
provider: provider.to_string(),
|
||||
input_token_count: tokens.input as u64,
|
||||
cache_creation_input_token_count: tokens.input_cache_creation as u64,
|
||||
cache_read_input_token_count: tokens.input_cache_read as u64,
|
||||
output_token_count: tokens.output as u64,
|
||||
input_token_count: input_token_count as u64,
|
||||
output_token_count: output_token_count as u64,
|
||||
requests_this_minute: usage.requests_this_minute as u64,
|
||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||
tokens_this_day: usage.tokens_this_day as u64,
|
||||
input_tokens_this_month: usage.tokens_this_month.input as u64,
|
||||
cache_creation_input_tokens_this_month: usage
|
||||
.tokens_this_month
|
||||
.input_cache_creation
|
||||
as u64,
|
||||
cache_read_input_tokens_this_month: usage
|
||||
.tokens_this_month
|
||||
.input_cache_read
|
||||
as u64,
|
||||
output_tokens_this_month: usage.tokens_this_month.output as u64,
|
||||
spending_this_month: usage.spending_this_month.0 as u64,
|
||||
lifetime_spending: usage.lifetime_spending.0 as u64,
|
||||
input_tokens_this_month: usage.input_tokens_this_month as u64,
|
||||
output_tokens_this_month: usage.output_tokens_this_month as u64,
|
||||
spending_this_month: usage.spending_this_month as u64,
|
||||
lifetime_spending: usage.lifetime_spending as u64,
|
||||
},
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -20,7 +20,7 @@ use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
pub use queries::usages::{ActiveUserCount, TokenUsage};
|
||||
pub use queries::usages::ActiveUserCount;
|
||||
use sea_orm::prelude::*;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::{
|
||||
@@ -97,14 +97,6 @@ impl LlmDatabase {
|
||||
.ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
|
||||
}
|
||||
|
||||
pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
.values()
|
||||
.find(|model| model.id == id)
|
||||
.ok_or_else(|| anyhow!("no model for ID {id:?}"))?)
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.options
|
||||
}
|
||||
|
||||
@@ -3,9 +3,8 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::id_type;
|
||||
|
||||
id_type!(BillingEventId);
|
||||
id_type!(ModelId);
|
||||
id_type!(ProviderId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
id_type!(UsageId);
|
||||
id_type!(UsageMeasureId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod revoked_access_tokens;
|
||||
pub mod usages;
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use super::*;
|
||||
use crate::Result;
|
||||
use anyhow::Context as _;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let events_with_models = billing_event::Entity::find()
|
||||
.find_also_related(model::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
events_with_models
|
||||
.into_iter()
|
||||
.map(|(event, model)| {
|
||||
let model =
|
||||
model.context("could not find model associated with billing event")?;
|
||||
Ok((event, model))
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::llm::Cents;
|
||||
use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use chrono::{Datelike, Duration};
|
||||
use crate::db::UserId;
|
||||
use chrono::Duration;
|
||||
use futures::StreamExt as _;
|
||||
use rpc::LanguageModelProvider;
|
||||
use sea_orm::QuerySelect;
|
||||
@@ -9,28 +8,15 @@ use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Default)]
|
||||
pub struct TokenUsage {
|
||||
pub input: usize,
|
||||
pub input_cache_creation: usize,
|
||||
pub input_cache_read: usize,
|
||||
pub output: usize,
|
||||
}
|
||||
|
||||
impl TokenUsage {
|
||||
pub fn total(&self) -> usize {
|
||||
self.input + self.input_cache_creation + self.input_cache_read + self.output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub tokens_this_day: usize,
|
||||
pub tokens_this_month: TokenUsage,
|
||||
pub spending_this_month: Cents,
|
||||
pub lifetime_spending: Cents,
|
||||
pub input_tokens_this_month: usize,
|
||||
pub output_tokens_this_month: usize,
|
||||
pub spending_this_month: usize,
|
||||
pub lifetime_spending: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
@@ -152,46 +138,6 @@ impl LlmDatabase {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_user_spending_for_month(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Cents> {
|
||||
self.transaction(|tx| async move {
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
|
||||
let mut monthly_usages = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
let mut monthly_spending = Cents::ZERO;
|
||||
|
||||
while let Some(usage) = monthly_usages.next().await {
|
||||
let usage = usage?;
|
||||
let Ok(model) = self.model_by_id(usage.model_id) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
monthly_spending += calculate_spending(
|
||||
model,
|
||||
usage.input_tokens as usize,
|
||||
usage.cache_creation_input_tokens as usize,
|
||||
usage.cache_read_input_tokens as usize,
|
||||
usage.output_tokens as usize,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(monthly_spending)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
@@ -214,26 +160,17 @@ impl LlmDatabase {
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
let monthly_usage = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::ModelId.eq(model.id))
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
let (lifetime_input_tokens, lifetime_output_tokens) = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
.await?
|
||||
.map_or((0, 0), |usage| {
|
||||
(usage.input_tokens as usize, usage.output_tokens as usize)
|
||||
});
|
||||
|
||||
let requests_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
|
||||
@@ -241,47 +178,21 @@ impl LlmDatabase {
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
|
||||
let tokens_this_day =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
|
||||
let spending_this_month = if let Some(monthly_usage) = &monthly_usage {
|
||||
calculate_spending(
|
||||
model,
|
||||
monthly_usage.input_tokens as usize,
|
||||
monthly_usage.cache_creation_input_tokens as usize,
|
||||
monthly_usage.cache_read_input_tokens as usize,
|
||||
monthly_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
};
|
||||
let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage {
|
||||
calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.cache_creation_input_tokens as usize,
|
||||
lifetime_usage.cache_read_input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
)
|
||||
} else {
|
||||
Cents::ZERO
|
||||
};
|
||||
let input_tokens_this_month =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?;
|
||||
let output_tokens_this_month =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?;
|
||||
let spending_this_month =
|
||||
calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
|
||||
let lifetime_spending =
|
||||
calculate_spending(model, lifetime_input_tokens, lifetime_output_tokens);
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.input_tokens as usize),
|
||||
input_cache_creation: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||
input_cache_read: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||
output: monthly_usage
|
||||
.as_ref()
|
||||
.map_or(0, |usage| usage.output_tokens as usize),
|
||||
},
|
||||
input_tokens_this_month,
|
||||
output_tokens_this_month,
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
@@ -296,9 +207,8 @@ impl LlmDatabase {
|
||||
is_staff: bool,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
tokens: TokenUsage,
|
||||
has_llm_subscription: bool,
|
||||
max_monthly_spend: Cents,
|
||||
input_token_count: usize,
|
||||
output_token_count: usize,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
@@ -333,7 +243,7 @@ impl LlmDatabase {
|
||||
&usages,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
now,
|
||||
tokens.total(),
|
||||
input_token_count + output_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
@@ -345,90 +255,36 @@ impl LlmDatabase {
|
||||
&usages,
|
||||
UsageMeasure::TokensPerDay,
|
||||
now,
|
||||
tokens.total(),
|
||||
input_token_count + output_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let month = now.date_naive().month() as i32;
|
||||
let year = now.date_naive().year();
|
||||
|
||||
// Update monthly usage
|
||||
let monthly_usage = monthly_usage::Entity::find()
|
||||
.filter(
|
||||
monthly_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(monthly_usage::Column::ModelId.eq(model.id))
|
||||
.and(monthly_usage::Column::Month.eq(month))
|
||||
.and(monthly_usage::Column::Year.eq(year)),
|
||||
let input_tokens_this_month = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::InputTokensPerMonth,
|
||||
now,
|
||||
input_token_count,
|
||||
&tx,
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
let monthly_usage = match monthly_usage {
|
||||
Some(usage) => {
|
||||
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
monthly_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
month: ActiveValue::set(month),
|
||||
year: ActiveValue::set(year),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let spending_this_month = calculate_spending(
|
||||
model,
|
||||
monthly_usage.input_tokens as usize,
|
||||
monthly_usage.cache_creation_input_tokens as usize,
|
||||
monthly_usage.cache_read_input_tokens as usize,
|
||||
monthly_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
if !is_staff
|
||||
&& spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
|
||||
&& has_llm_subscription
|
||||
&& spending_this_month <= max_monthly_spend
|
||||
{
|
||||
billing_event::ActiveModel {
|
||||
id: ActiveValue::not_set(),
|
||||
idempotency_key: ActiveValue::not_set(),
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
input_cache_creation_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
}
|
||||
.insert(&*tx)
|
||||
let output_tokens_this_month = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::OutputTokensPerMonth,
|
||||
now,
|
||||
output_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
let spending_this_month =
|
||||
calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
|
||||
|
||||
// Update lifetime usage
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
@@ -444,14 +300,12 @@ impl LlmDatabase {
|
||||
Some(usage) => {
|
||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||
input_tokens: ActiveValue::set(
|
||||
usage.input_tokens + input_token_count as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(
|
||||
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||
output_tokens: ActiveValue::set(
|
||||
usage.output_tokens + output_token_count as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
@@ -461,12 +315,8 @@ impl LlmDatabase {
|
||||
lifetime_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||
cache_creation_input_tokens: ActiveValue::set(
|
||||
tokens.input_cache_creation as i64,
|
||||
),
|
||||
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
@@ -477,8 +327,6 @@ impl LlmDatabase {
|
||||
let lifetime_spending = calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.cache_creation_input_tokens as usize,
|
||||
lifetime_usage.cache_read_input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
@@ -486,12 +334,8 @@ impl LlmDatabase {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: monthly_usage.input_tokens as usize,
|
||||
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
|
||||
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
|
||||
output: monthly_usage.output_tokens as usize,
|
||||
},
|
||||
input_tokens_this_month,
|
||||
output_tokens_this_month,
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
@@ -657,29 +501,18 @@ impl LlmDatabase {
|
||||
fn calculate_spending(
|
||||
model: &model::Model,
|
||||
input_tokens_this_month: usize,
|
||||
cache_creation_input_tokens_this_month: usize,
|
||||
cache_read_input_tokens_this_month: usize,
|
||||
output_tokens_this_month: usize,
|
||||
) -> Cents {
|
||||
) -> usize {
|
||||
let input_token_cost =
|
||||
input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000;
|
||||
let cache_creation_input_token_cost = cache_creation_input_tokens_this_month
|
||||
* model.price_per_million_cache_creation_input_tokens as usize
|
||||
/ 1_000_000;
|
||||
let cache_read_input_token_cost = cache_read_input_tokens_this_month
|
||||
* model.price_per_million_cache_read_input_tokens as usize
|
||||
/ 1_000_000;
|
||||
let output_token_cost =
|
||||
output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000;
|
||||
let spending = input_token_cost
|
||||
+ cache_creation_input_token_cost
|
||||
+ cache_read_input_token_cost
|
||||
+ output_token_cost;
|
||||
Cents::new(spending as u32)
|
||||
input_token_cost + output_token_cost
|
||||
}
|
||||
|
||||
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||
const DAY_BUCKET_COUNT: usize = 48;
|
||||
const MONTH_BUCKET_COUNT: usize = 30;
|
||||
|
||||
impl UsageMeasure {
|
||||
fn bucket_count(&self) -> usize {
|
||||
@@ -687,6 +520,8 @@ impl UsageMeasure {
|
||||
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
|
||||
UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT,
|
||||
UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -695,6 +530,8 @@ impl UsageMeasure {
|
||||
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerDay => Duration::hours(24),
|
||||
UsageMeasure::InputTokensPerMonth => Duration::days(30),
|
||||
UsageMeasure::OutputTokensPerMonth => Duration::days(30),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
pub mod billing_event;
|
||||
pub mod lifetime_usage;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
pub mod provider;
|
||||
pub mod revoked_access_token;
|
||||
pub mod usage;
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{BillingEventId, ModelId},
|
||||
};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingEventId,
|
||||
pub idempotency_key: Uuid,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub input_cache_creation_tokens: i64,
|
||||
pub input_cache_read_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -9,8 +9,6 @@ pub struct Model {
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub cache_creation_input_tokens: i64,
|
||||
pub cache_read_input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
|
||||
@@ -14,8 +14,6 @@ pub struct Model {
|
||||
pub max_tokens_per_minute: i64,
|
||||
pub max_tokens_per_day: i64,
|
||||
pub price_per_million_input_tokens: i32,
|
||||
pub price_per_million_cache_creation_input_tokens: i32,
|
||||
pub price_per_million_cache_read_input_tokens: i32,
|
||||
pub price_per_million_output_tokens: i32,
|
||||
}
|
||||
|
||||
@@ -29,8 +27,6 @@ pub enum Relation {
|
||||
Provider,
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
#[sea_orm(has_many = "super::billing_event::Entity")]
|
||||
BillingEvents,
|
||||
}
|
||||
|
||||
impl Related<super::provider::Entity> for Entity {
|
||||
@@ -45,10 +41,4 @@ impl Related<super::usage::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_event::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingEvents.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
use crate::{db::UserId, llm::db::ModelId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "monthly_usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub month: i32,
|
||||
pub year: i32,
|
||||
pub input_tokens: i64,
|
||||
pub cache_creation_input_tokens: i64,
|
||||
pub cache_read_input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -9,6 +9,8 @@ pub enum UsageMeasure {
|
||||
RequestsPerMinute,
|
||||
TokensPerMinute,
|
||||
TokensPerDay,
|
||||
InputTokensPerMonth,
|
||||
OutputTokensPerMonth,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
mod billing_tests;
|
||||
mod provider_tests;
|
||||
mod usage_tests;
|
||||
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::{
|
||||
db::{
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
LlmDatabase, TokenUsage,
|
||||
},
|
||||
FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
},
|
||||
test_llm_db, Cents,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(
|
||||
test_billing_limit_exceeded,
|
||||
test_billing_limit_exceeded_postgres
|
||||
);
|
||||
|
||||
async fn test_billing_limit_exceeded(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "fake-claude-limerick";
|
||||
const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5;
|
||||
const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5;
|
||||
|
||||
// Initialize the database and insert the model
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS,
|
||||
price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Set a fixed datetime for consistent testing
|
||||
let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let max_monthly_spend = Cents::from_dollars(10);
|
||||
|
||||
// Record usage that brings us close to the limit but doesn't exceed it
|
||||
// Let's say we use $9.50 worth of tokens
|
||||
let tokens_to_use = 190_000_000; // This will cost $9.50 at $0.05 per 1 million tokens
|
||||
let usage = TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
let cost = Cents::new(tokens_to_use as u32 / 1_000_000 * PRICE_PER_MILLION_INPUT_TOKENS as u32);
|
||||
|
||||
assert_eq!(
|
||||
cost,
|
||||
Cents::new(950),
|
||||
"expected the cost to be $9.50, based on the inputs, but it wasn't"
|
||||
);
|
||||
|
||||
// Verify that before we record any usage, there are 0 billing events
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 0);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the recorded usage and spending
|
||||
let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
|
||||
// Verify that we exceeded the free tier usage
|
||||
assert!(
|
||||
recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT,
|
||||
"Expected spending to exceed free tier limit"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
recorded_usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: tokens_to_use,
|
||||
tokens_this_day: tokens_to_use,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: tokens_to_use,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::new(950),
|
||||
lifetime_spending: Cents::new(950),
|
||||
}
|
||||
);
|
||||
|
||||
// Verify that there is one `billing_event` record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
let (billing_event, _model) = &billing_events[0];
|
||||
assert_eq!(billing_event.user_id, user_id);
|
||||
assert_eq!(billing_event.input_tokens, tokens_to_use as i64);
|
||||
assert_eq!(billing_event.input_cache_creation_tokens, 0);
|
||||
assert_eq!(billing_event.input_cache_read_tokens, 0);
|
||||
assert_eq!(billing_event.output_tokens, 0);
|
||||
|
||||
let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $9.50 to $10.50, which is over the $10 monthly maximum limit
|
||||
let usage_exceeding = TokenUsage {
|
||||
input: tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
};
|
||||
|
||||
// This should still create a billing event as it's the first request that exceeds the limit
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
usage_exceeding,
|
||||
true,
|
||||
max_monthly_spend,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify that there is still one billing record
|
||||
let billing_events = db.get_billing_events().await.unwrap();
|
||||
assert_eq!(billing_events.len(), 1);
|
||||
|
||||
// Verify the updated usage and spending
|
||||
let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
updated_usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: tokens_to_use + tokens_to_exceed,
|
||||
tokens_this_day: tokens_to_use + tokens_to_exceed,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: tokens_to_use + tokens_to_exceed,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::new(1050),
|
||||
lifetime_spending: Cents::new(1050),
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -2,11 +2,11 @@ use crate::{
|
||||
db::UserId,
|
||||
llm::db::{
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
LlmDatabase, TokenUsage,
|
||||
LlmDatabase,
|
||||
},
|
||||
test_llm_db, Cents,
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use chrono::{Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
@@ -29,49 +29,18 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We're using a fixed datetime to prevent flakiness based on the clock.
|
||||
let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
let t0 = Utc::now();
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let now = t0;
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 1000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let now = t0 + Duration::seconds(10);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 2000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -80,14 +49,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 3000,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 3000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -99,35 +64,17 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 2000,
|
||||
tokens_this_day: 3000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 3000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 3000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 3000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -136,14 +83,10 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 5000,
|
||||
tokens_this_day: 6000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 6000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
@@ -156,34 +99,16 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
tokens_this_day: 5000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 6000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 6000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 4000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.record_usage(user_id, false, provider, model, 4000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
@@ -192,93 +117,26 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 4000,
|
||||
tokens_this_day: 9000,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 10000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
input_tokens_this_month: 10000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
// We're using a fixed datetime to prevent flakiness based on the clock.
|
||||
let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z")
|
||||
.unwrap()
|
||||
.with_timezone(&Utc);
|
||||
|
||||
// Test cache creation input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let t2 = t0 + Duration::days(30);
|
||||
let now = t2;
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 1500,
|
||||
tokens_this_day: 1500,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 0,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
}
|
||||
);
|
||||
|
||||
// Test cache read input tokens
|
||||
db.record_usage(
|
||||
user_id,
|
||||
false,
|
||||
provider,
|
||||
model,
|
||||
TokenUsage {
|
||||
input: 1000,
|
||||
input_cache_creation: 0,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
false,
|
||||
Cents::ZERO,
|
||||
now,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 2800,
|
||||
tokens_this_day: 2800,
|
||||
tokens_this_month: TokenUsage {
|
||||
input: 2000,
|
||||
input_cache_creation: 500,
|
||||
input_cache_read: 300,
|
||||
output: 0,
|
||||
},
|
||||
spending_this_month: Cents::ZERO,
|
||||
lifetime_spending: Cents::ZERO,
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
tokens_this_day: 0,
|
||||
input_tokens_this_month: 9000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -12,15 +12,11 @@ pub struct LlmUsageEventRow {
|
||||
pub model: String,
|
||||
pub provider: String,
|
||||
pub input_token_count: u64,
|
||||
pub cache_creation_input_token_count: u64,
|
||||
pub cache_read_input_token_count: u64,
|
||||
pub output_token_count: u64,
|
||||
pub requests_this_minute: u64,
|
||||
pub tokens_this_minute: u64,
|
||||
pub tokens_this_day: u64,
|
||||
pub input_tokens_this_month: u64,
|
||||
pub cache_creation_input_tokens_this_month: u64,
|
||||
pub cache_read_input_tokens_this_month: u64,
|
||||
pub output_tokens_this_month: u64,
|
||||
pub spending_this_month: u64,
|
||||
pub lifetime_spending: u64,
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::{
|
||||
db::{billing_preference, UserId},
|
||||
Config,
|
||||
};
|
||||
use crate::{db::UserId, Config};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
@@ -17,25 +13,26 @@ pub struct LlmTokenClaims {
|
||||
pub exp: u64,
|
||||
pub jti: String,
|
||||
pub user_id: u64,
|
||||
pub github_user_login: String,
|
||||
// This field is temporarily optional so it can be added
|
||||
// in a backwards-compatible way. We can make it required
|
||||
// once all of the LLM tokens have cycled (~1 hour after
|
||||
// this change has been deployed).
|
||||
#[serde(default)]
|
||||
pub github_user_login: Option<String>,
|
||||
pub is_staff: bool,
|
||||
#[serde(default)]
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
pub has_llm_subscription: bool,
|
||||
pub max_monthly_spend_in_cents: u32,
|
||||
pub plan: rpc::proto::Plan,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
impl LlmTokenClaims {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn create(
|
||||
user_id: UserId,
|
||||
github_user_login: String,
|
||||
is_staff: bool,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
has_llm_closed_beta_feature_flag: bool,
|
||||
has_llm_subscription: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
@@ -50,14 +47,9 @@ impl LlmTokenClaims {
|
||||
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
|
||||
jti: uuid::Uuid::new_v4().to_string(),
|
||||
user_id: user_id.to_proto(),
|
||||
github_user_login,
|
||||
github_user_login: Some(github_user_login),
|
||||
is_staff,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
has_llm_subscription,
|
||||
max_monthly_spend_in_cents: billing_preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents as u32
|
||||
}),
|
||||
plan,
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ use axum::{
|
||||
routing::get,
|
||||
Extension, Router,
|
||||
};
|
||||
use collab::api::billing::sync_llm_usage_with_stripe_periodically;
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::llm::{db::LlmDatabase, log_usage_periodically};
|
||||
use collab::migrations::run_database_migrations;
|
||||
@@ -30,7 +29,7 @@ use tower_http::trace::TraceLayer;
|
||||
use tracing_subscriber::{
|
||||
filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
|
||||
};
|
||||
use util::{maybe, ResultExt as _};
|
||||
use util::ResultExt as _;
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
|
||||
@@ -111,13 +110,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
if let Some(stripe_billing) = state.stripe_billing.clone() {
|
||||
let executor = state.executor.clone();
|
||||
executor.spawn_detached(async move {
|
||||
stripe_billing.initialize().await.trace_err();
|
||||
});
|
||||
}
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
RateLimiter::save_periodically(
|
||||
@@ -144,29 +136,6 @@ async fn main() -> Result<()> {
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
spawn_user_backfiller(state.clone());
|
||||
|
||||
let llm_db = maybe!(async {
|
||||
let database_url = state
|
||||
.config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
let max_connections = state
|
||||
.config
|
||||
.llm_database_max_connections
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
LlmDatabase::new(db_options, state.executor.clone()).await
|
||||
})
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
if let Some(mut llm_db) = llm_db {
|
||||
llm_db.initialize().await?;
|
||||
sync_llm_usage_with_stripe_periodically(state.clone());
|
||||
}
|
||||
|
||||
app = app
|
||||
.merge(collab::api::events::router())
|
||||
.merge(collab::api::extensions::router())
|
||||
|
||||
@@ -191,26 +191,16 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn has_llm_subscription(
|
||||
&self,
|
||||
db: &MutexGuard<'_, DbHandle>,
|
||||
) -> anyhow::Result<bool> {
|
||||
pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
|
||||
if self.is_staff() {
|
||||
return Ok(true);
|
||||
return Ok(proto::Plan::ZedPro);
|
||||
}
|
||||
|
||||
let Some(user_id) = self.user_id() else {
|
||||
return Ok(false);
|
||||
return Ok(proto::Plan::Free);
|
||||
};
|
||||
|
||||
Ok(db.has_active_billing_subscription(user_id).await?)
|
||||
}
|
||||
|
||||
pub async fn current_plan(
|
||||
&self,
|
||||
_db: &MutexGuard<'_, DbHandle>,
|
||||
) -> anyhow::Result<proto::Plan> {
|
||||
if self.is_staff() {
|
||||
if db.has_active_billing_subscription(user_id).await? {
|
||||
Ok(proto::Plan::ZedPro)
|
||||
} else {
|
||||
Ok(proto::Plan::Free)
|
||||
@@ -469,6 +459,9 @@ impl Server {
|
||||
.add_request_handler(user_handler(
|
||||
forward_project_request_for_owner::<proto::TaskContextForLocation>,
|
||||
))
|
||||
.add_request_handler(user_handler(
|
||||
forward_project_request_for_owner::<proto::TaskTemplates>,
|
||||
))
|
||||
.add_request_handler(user_handler(
|
||||
forward_read_only_project_request::<proto::GetHover>,
|
||||
))
|
||||
@@ -1218,15 +1211,6 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
self.peer
|
||||
.send(connection_id, proto::RefreshLlmToken {})
|
||||
.trace_err();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
||||
ServerSnapshot {
|
||||
connection_pool: ConnectionPoolGuard {
|
||||
@@ -3487,7 +3471,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
|
||||
}
|
||||
|
||||
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
|
||||
let plan = session.current_plan(&session.db().await).await?;
|
||||
let plan = session.current_plan(session.db().await).await?;
|
||||
|
||||
session
|
||||
.peer
|
||||
@@ -4487,7 +4471,7 @@ async fn count_language_model_tokens(
|
||||
};
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
|
||||
};
|
||||
@@ -4608,7 +4592,7 @@ async fn compute_embeddings(
|
||||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
|
||||
};
|
||||
@@ -4926,17 +4910,12 @@ async fn get_llm_api_token(
|
||||
if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
|
||||
Err(anyhow!("account too young"))?
|
||||
}
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
user.id,
|
||||
user.github_login.clone(),
|
||||
session.is_staff(),
|
||||
billing_preferences,
|
||||
has_llm_closed_beta_feature_flag,
|
||||
session.has_llm_subscription(&db).await?,
|
||||
session.current_plan(&db).await?,
|
||||
session.current_plan(db).await?,
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
response.send(proto::GetLlmTokenResponse { token })?;
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{llm, Cents, Result};
|
||||
use anyhow::Context;
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
client: Arc<stripe::Client>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct StripeBillingState {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||
}
|
||||
|
||||
pub struct StripeModel {
|
||||
input_tokens_price: StripeBillingPrice,
|
||||
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||
input_cache_read_tokens_price: StripeBillingPrice,
|
||||
output_tokens_price: StripeBillingPrice,
|
||||
}
|
||||
|
||||
struct StripeBillingPrice {
|
||||
id: stripe::PriceId,
|
||||
meter_event_name: String,
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
state: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
log::info!("StripeBilling: initializing");
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
let (meters, prices) = futures::try_join!(
|
||||
StripeMeter::list(&self.client),
|
||||
stripe::Price::list(&self.client, &stripe::ListPrices::default())
|
||||
)?;
|
||||
|
||||
for meter in meters.data {
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter.event_name.clone(), meter);
|
||||
}
|
||||
|
||||
for price in prices.data {
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
state.price_ids_by_meter_id.insert(meter, price.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("StripeBilling: initialized");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||
let input_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_tokens", model.id),
|
||||
&format!("{} (Input Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_creation_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_cache_creation_tokens", model.id),
|
||||
&format!("{} (Input Cache Creation Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_read_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_cache_read_tokens", model.id),
|
||||
&format!("{} (Input Cache Read Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/output_tokens", model.id),
|
||||
&format!("{} (Output Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_output_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
Ok(StripeModel {
|
||||
input_tokens_price,
|
||||
input_cache_creation_tokens_price,
|
||||
input_cache_read_tokens_price,
|
||||
output_tokens_price,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_or_insert_price(
|
||||
&self,
|
||||
meter_event_name: &str,
|
||||
price_description: &str,
|
||||
price_per_million_tokens: Cents,
|
||||
) -> Result<StripeBillingPrice> {
|
||||
// Fast code path when the meter and the price already exist.
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
return Ok(StripeBillingPrice {
|
||||
id: price_id.clone(),
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
meter.clone()
|
||||
} else {
|
||||
let meter = StripeMeter::create(
|
||||
&self.client,
|
||||
StripeCreateMeterParams {
|
||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||
display_name: price_description.to_string(),
|
||||
event_name: meter_event_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter_event_name.to_string(), meter.clone());
|
||||
meter
|
||||
};
|
||||
|
||||
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
price_id.clone()
|
||||
} else {
|
||||
let price = stripe::Price::create(
|
||||
&self.client,
|
||||
stripe::CreatePrice {
|
||||
active: Some(true),
|
||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||
currency: stripe::Currency::USD,
|
||||
currency_options: None,
|
||||
custom_unit_amount: None,
|
||||
expand: &[],
|
||||
lookup_key: None,
|
||||
metadata: None,
|
||||
nickname: None,
|
||||
product: None,
|
||||
product_data: Some(stripe::CreatePriceProductData {
|
||||
id: None,
|
||||
active: Some(true),
|
||||
metadata: None,
|
||||
name: price_description.to_string(),
|
||||
statement_descriptor: None,
|
||||
tax_code: None,
|
||||
unit_label: None,
|
||||
}),
|
||||
recurring: Some(stripe::CreatePriceRecurring {
|
||||
aggregate_usage: None,
|
||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||
interval_count: None,
|
||||
trial_period_days: None,
|
||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||
meter: Some(meter.id.clone()),
|
||||
}),
|
||||
tax_behavior: None,
|
||||
tiers: None,
|
||||
tiers_mode: None,
|
||||
transfer_lookup_key: None,
|
||||
transform_quantity: None,
|
||||
unit_amount: None,
|
||||
unit_amount_decimal: Some(&format!(
|
||||
"{:.12}",
|
||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.price_ids_by_meter_id
|
||||
.insert(meter.id, price.id.clone());
|
||||
price.id
|
||||
};
|
||||
|
||||
Ok(StripeBillingPrice {
|
||||
id: price_id,
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_model(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
model: &StripeModel,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
|
||||
{
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_read_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.output_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !items.is_empty() {
|
||||
items.extend(subscription.items.data.iter().map(|item| {
|
||||
stripe::UpdateSubscriptionItems {
|
||||
id: Some(item.id.to_string()),
|
||||
..Default::default()
|
||||
}
|
||||
}));
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(items),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
model: &StripeModel,
|
||||
event: &llm::db::billing_event::Model,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
|
||||
if event.input_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_creation_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_creation_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_read_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_read_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_read_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.output_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("output_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.output_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.output_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn checkout(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
model: &StripeModel,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(
|
||||
[
|
||||
&model.input_tokens_price.id,
|
||||
&model.input_cache_creation_tokens_price.id,
|
||||
&model.input_cache_read_tokens_price.id,
|
||||
&model.output_tokens_price.id,
|
||||
]
|
||||
.into_iter()
|
||||
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(price_id.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DefaultAggregation {
|
||||
formula: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterParams<'a> {
|
||||
default_aggregation: DefaultAggregation,
|
||||
display_name: String,
|
||||
event_name: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
struct StripeMeter {
|
||||
id: String,
|
||||
event_name: String,
|
||||
}
|
||||
|
||||
impl StripeMeter {
|
||||
pub fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterParams,
|
||||
) -> stripe::Response<Self> {
|
||||
client.post_form("/billing/meters", params)
|
||||
}
|
||||
|
||||
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {}
|
||||
|
||||
client.get_query("/billing/meters", Params {})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StripeMeterEvent {
|
||||
identifier: String,
|
||||
}
|
||||
|
||||
impl StripeMeterEvent {
|
||||
pub async fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterEventParams<'_>,
|
||||
) -> Result<Self, stripe::StripeError> {
|
||||
let identifier = params.identifier;
|
||||
match client.post_form("/billing/meter_events", params).await {
|
||||
Ok(event) => Ok(event),
|
||||
Err(stripe::StripeError::Stripe(error)) => {
|
||||
if error.http_status == 400
|
||||
&& error
|
||||
.message
|
||||
.as_ref()
|
||||
.map_or(false, |message| message.contains(identifier))
|
||||
{
|
||||
Ok(Self {
|
||||
identifier: identifier.to_string(),
|
||||
})
|
||||
} else {
|
||||
Err(stripe::StripeError::Stripe(error))
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterEventParams<'a> {
|
||||
identifier: &'a str,
|
||||
event_name: &'a str,
|
||||
payload: StripeCreateMeterEventPayload<'a>,
|
||||
timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterEventPayload<'a> {
|
||||
value: u64,
|
||||
stripe_customer_id: &'a stripe::CustomerId,
|
||||
}
|
||||
|
||||
fn subscription_contains_price(
|
||||
subscription: &stripe::Subscription,
|
||||
price_id: &stripe::PriceId,
|
||||
) -> bool {
|
||||
subscription.items.data.iter().any(|item| {
|
||||
item.price
|
||||
.as_ref()
|
||||
.map_or(false, |price| price.id == *price_id)
|
||||
})
|
||||
}
|
||||
@@ -50,7 +50,7 @@ async fn test_channel_guests(
|
||||
project_b.read_with(cx_b, |project, _| project.remote_id()),
|
||||
Some(project_id),
|
||||
);
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(project_b
|
||||
.update(cx_b, |project, cx| {
|
||||
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
|
||||
@@ -103,7 +103,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
workspace.active_item_as::<Editor>(cx).unwrap(),
|
||||
)
|
||||
});
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(editor_b.update(cx_b, |e, cx| e.read_only(cx)));
|
||||
assert!(room_b.read_with(cx_b, |room, _| !room.can_use_microphone()));
|
||||
assert!(room_b
|
||||
@@ -127,7 +127,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
cx_a.run_until_parked();
|
||||
|
||||
// project and buffers are now editable
|
||||
assert!(project_b.read_with(cx_b, |project, cx| !project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| !project.is_read_only()));
|
||||
assert!(editor_b.update(cx_b, |editor, cx| !editor.read_only(cx)));
|
||||
|
||||
// B sees themselves as muted, and can unmute.
|
||||
@@ -153,7 +153,7 @@ async fn test_channel_guest_promotion(cx_a: &mut TestAppContext, cx_b: &mut Test
|
||||
cx_a.run_until_parked();
|
||||
|
||||
// project and buffers are no longer editable
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_read_only(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
|
||||
assert!(editor_b.update(cx_b, |editor, cx| editor.read_only(cx)));
|
||||
assert!(room_b
|
||||
.update(cx_b, |room, cx| room.share_microphone(cx))
|
||||
|
||||
@@ -262,7 +262,7 @@ async fn test_dev_server_leave_room(
|
||||
cx1.executor().run_until_parked();
|
||||
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -308,7 +308,7 @@ async fn test_dev_server_delete(
|
||||
cx1.executor().run_until_parked();
|
||||
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
|
||||
cx1.update(|cx| {
|
||||
dev_server_projects::Store::global(cx).update(cx, |store, _| {
|
||||
@@ -418,12 +418,12 @@ async fn test_dev_server_refresh_access_token(
|
||||
|
||||
// Assert that the other client was disconnected
|
||||
let (workspace, cx2) = client2.active_workspace(cx2);
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected(cx)));
|
||||
cx2.update(|cx| assert!(workspace.read(cx).project().read(cx).is_disconnected()));
|
||||
|
||||
// Assert that the owner of the dev server does not see the dev server as online anymore
|
||||
let (workspace, cx1) = client1.active_workspace(cx1);
|
||||
cx1.update(|cx| {
|
||||
assert!(workspace.read(cx).project().read(cx).is_disconnected(cx));
|
||||
assert!(workspace.read(cx).project().read(cx).is_disconnected());
|
||||
dev_server_projects::Store::global(cx).update(cx, |store, _| {
|
||||
assert_eq!(
|
||||
store.dev_servers().first().unwrap().status,
|
||||
|
||||
@@ -114,7 +114,7 @@ async fn test_host_disconnect(
|
||||
|
||||
project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
|
||||
|
||||
project_b.read_with(cx_b, |project, cx| project.is_read_only(cx));
|
||||
project_b.read_with(cx_b, |project, _| project.is_read_only());
|
||||
|
||||
assert!(worktree_a.read_with(cx_a, |tree, _| !tree.has_update_observer()));
|
||||
|
||||
|
||||
@@ -1389,7 +1389,7 @@ async fn test_unshare_project(
|
||||
.unwrap();
|
||||
executor.run_until_parked();
|
||||
|
||||
assert!(project_b.read_with(cx_b, |project, cx| project.is_disconnected(cx)));
|
||||
assert!(project_b.read_with(cx_b, |project, _| project.is_disconnected()));
|
||||
|
||||
// Client C opens the project.
|
||||
let project_c = client_c.join_remote_project(project_id, cx_c).await;
|
||||
@@ -1402,7 +1402,7 @@ async fn test_unshare_project(
|
||||
|
||||
assert!(worktree_a.read_with(cx_a, |tree, _| !tree.has_update_observer()));
|
||||
|
||||
assert!(project_c.read_with(cx_c, |project, cx| project.is_disconnected(cx)));
|
||||
assert!(project_c.read_with(cx_c, |project, _| project.is_disconnected()));
|
||||
|
||||
// Client C can open the project again after client A re-shares.
|
||||
let project_id = active_call_a
|
||||
@@ -1427,8 +1427,8 @@ async fn test_unshare_project(
|
||||
|
||||
project_a.read_with(cx_a, |project, _| assert!(!project.is_shared()));
|
||||
|
||||
project_c2.read_with(cx_c, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_c2.read_with(cx_c, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
assert!(project.collaborators().is_empty());
|
||||
});
|
||||
}
|
||||
@@ -1560,8 +1560,8 @@ async fn test_project_reconnect(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
project_b1.read_with(cx_b, |project, _| {
|
||||
assert!(!project.is_disconnected());
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
@@ -1661,7 +1661,7 @@ async fn test_project_reconnect(
|
||||
});
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert!(!project.is_disconnected());
|
||||
assert_eq!(
|
||||
project
|
||||
.worktree_for_id(worktree1_id, cx)
|
||||
@@ -1695,9 +1695,9 @@ async fn test_project_reconnect(
|
||||
);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, cx| assert!(project.is_disconnected(cx)));
|
||||
project_b2.read_with(cx_b, |project, _| assert!(project.is_disconnected()));
|
||||
|
||||
project_b3.read_with(cx_b, |project, cx| assert!(!project.is_disconnected(cx)));
|
||||
project_b3.read_with(cx_b, |project, _| assert!(!project.is_disconnected()));
|
||||
|
||||
buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WaZ"));
|
||||
|
||||
@@ -1754,7 +1754,7 @@ async fn test_project_reconnect(
|
||||
executor.run_until_parked();
|
||||
|
||||
project_b1.read_with(cx_b, |project, cx| {
|
||||
assert!(!project.is_disconnected(cx));
|
||||
assert!(!project.is_disconnected());
|
||||
assert_eq!(
|
||||
project
|
||||
.worktree_for_id(worktree1_id, cx)
|
||||
@@ -1788,7 +1788,7 @@ async fn test_project_reconnect(
|
||||
);
|
||||
});
|
||||
|
||||
project_b3.read_with(cx_b, |project, cx| assert!(project.is_disconnected(cx)));
|
||||
project_b3.read_with(cx_b, |project, _| assert!(project.is_disconnected()));
|
||||
|
||||
buffer_a1.read_with(cx_a, |buffer, _| assert_eq!(buffer.text(), "WXaYZ"));
|
||||
|
||||
@@ -3816,8 +3816,8 @@ async fn test_leaving_project(
|
||||
assert_eq!(project.collaborators().len(), 1);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_b2.read_with(cx_b, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
});
|
||||
|
||||
project_c.read_with(cx_c, |project, _| {
|
||||
@@ -3849,12 +3849,12 @@ async fn test_leaving_project(
|
||||
assert_eq!(project.collaborators().len(), 0);
|
||||
});
|
||||
|
||||
project_b2.read_with(cx_b, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_b2.read_with(cx_b, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
});
|
||||
|
||||
project_c.read_with(cx_c, |project, cx| {
|
||||
assert!(project.is_disconnected(cx));
|
||||
project_c.read_with(cx_c, |project, _| {
|
||||
assert!(project.is_disconnected());
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1168,7 +1168,7 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
Some((project, cx))
|
||||
});
|
||||
|
||||
if !guest_project.is_disconnected(cx) {
|
||||
if !guest_project.is_disconnected() {
|
||||
if let Some((host_project, host_cx)) = host_project {
|
||||
let host_worktree_snapshots =
|
||||
host_project.read_with(host_cx, |host_project, cx| {
|
||||
@@ -1254,8 +1254,8 @@ impl RandomizedTest for ProjectCollaborationTest {
|
||||
|
||||
let buffers = client.buffers().clone();
|
||||
for (guest_project, guest_buffers) in &buffers {
|
||||
let project_id = if guest_project.read_with(client_cx, |project, cx| {
|
||||
project.is_local() || project.is_disconnected(cx)
|
||||
let project_id = if guest_project.read_with(client_cx, |project, _| {
|
||||
project.is_local() || project.is_disconnected()
|
||||
}) {
|
||||
continue;
|
||||
} else {
|
||||
|
||||
@@ -532,9 +532,9 @@ impl<T: RandomizedTest> TestPlan<T> {
|
||||
server.allow_connections();
|
||||
|
||||
for project in client.dev_server_projects().iter() {
|
||||
project.read_with(&client_cx, |project, cx| {
|
||||
project.read_with(&client_cx, |project, _| {
|
||||
assert!(
|
||||
project.is_disconnected(cx),
|
||||
project.is_disconnected(),
|
||||
"project {:?} should be read only",
|
||||
project.remote_id()
|
||||
)
|
||||
|
||||
@@ -2,12 +2,10 @@ use crate::tests::TestServer;
|
||||
use call::ActiveCall;
|
||||
use fs::{FakeFs, Fs as _};
|
||||
use gpui::{Context as _, TestAppContext};
|
||||
use http_client::BlockedHttpClient;
|
||||
use language::{language_settings::all_language_settings, LanguageRegistry};
|
||||
use node_runtime::NodeRuntime;
|
||||
use language::language_settings::all_language_settings;
|
||||
use project::ProjectPath;
|
||||
use remote::SshRemoteClient;
|
||||
use remote_server::{HeadlessAppState, HeadlessProject};
|
||||
use remote_server::HeadlessProject;
|
||||
use serde_json::json;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
@@ -50,22 +48,8 @@ async fn test_sharing_an_ssh_remote_project(
|
||||
|
||||
// User A connects to the remote project via SSH.
|
||||
server_cx.update(HeadlessProject::init);
|
||||
let remote_http_client = Arc::new(BlockedHttpClient);
|
||||
let node = NodeRuntime::unavailable();
|
||||
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
|
||||
let _headless_project = server_cx.new_model(|cx| {
|
||||
client::init_settings(cx);
|
||||
HeadlessProject::new(
|
||||
HeadlessAppState {
|
||||
session: server_ssh,
|
||||
fs: remote_fs.clone(),
|
||||
http_client: remote_http_client,
|
||||
node_runtime: node,
|
||||
languages,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let _headless_project =
|
||||
server_cx.new_model(|cx| HeadlessProject::new(server_ssh, remote_fs.clone(), cx));
|
||||
|
||||
let (project_a, worktree_id) = client_a
|
||||
.build_ssh_project("/code/project1", client_ssh, cx_a)
|
||||
|
||||
@@ -635,11 +635,9 @@ impl TestServer {
|
||||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
llm_db: None,
|
||||
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
stripe_client: None,
|
||||
stripe_billing: None,
|
||||
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
|
||||
executor,
|
||||
clickhouse_client: None,
|
||||
@@ -679,6 +677,7 @@ impl TestServer {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
},
|
||||
|
||||
@@ -111,7 +111,7 @@ impl MessageEditor {
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.set_show_wrap_guides(false, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_completion_provider(Some(Box::new(MessageEditorCompletionProvider(this))));
|
||||
editor.set_completion_provider(Box::new(MessageEditorCompletionProvider(this)));
|
||||
editor.set_auto_replace_emoji_shortcode(
|
||||
MessageEditorSettings::get_global(cx)
|
||||
.auto_replace_emoji_shortcode
|
||||
|
||||
@@ -294,7 +294,6 @@ gpui::actions!(
|
||||
RevealInFileManager,
|
||||
ReverseLines,
|
||||
RevertFile,
|
||||
ReloadFile,
|
||||
RevertSelectedHunks,
|
||||
Rewrap,
|
||||
ScrollCursorBottom,
|
||||
|
||||
@@ -121,11 +121,10 @@ use multi_buffer::{
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use project::project_settings::{GitGutterSetting, ProjectSettings};
|
||||
use project::{
|
||||
lsp_store::FormatTrigger,
|
||||
project_settings::{GitGutterSetting, ProjectSettings},
|
||||
CodeAction, Completion, CompletionIntent, DocumentHighlight, InlayHint, Item, Location,
|
||||
LocationLink, Project, ProjectPath, ProjectTransaction, TaskSourceKind,
|
||||
lsp_store::FormatTrigger, CodeAction, Completion, CompletionIntent, Item, Location, Project,
|
||||
ProjectPath, ProjectTransaction, TaskSourceKind,
|
||||
};
|
||||
use rand::prelude::*;
|
||||
use rpc::{proto::*, ErrorExt};
|
||||
@@ -161,11 +160,11 @@ use ui::{
|
||||
};
|
||||
use util::{defer, maybe, post_inc, RangeExt, ResultExt, TryFutureExt};
|
||||
use workspace::item::{ItemHandle, PreviewTabsSettings};
|
||||
use workspace::notifications::{DetachAndPromptErr, NotificationId, NotifyTaskExt};
|
||||
use workspace::notifications::{DetachAndPromptErr, NotificationId};
|
||||
use workspace::{
|
||||
searchable::SearchEvent, ItemNavHistory, SplitDirection, ViewId, Workspace, WorkspaceId,
|
||||
};
|
||||
use workspace::{Item as WorkspaceItem, OpenInTerminal, OpenTerminal, TabBarSettings, Toast};
|
||||
use workspace::{OpenInTerminal, OpenTerminal, TabBarSettings, Toast};
|
||||
|
||||
use crate::hover_links::find_url;
|
||||
use crate::signature_help::{SignatureHelpHiddenBy, SignatureHelpState};
|
||||
@@ -547,7 +546,6 @@ pub struct Editor {
|
||||
active_diagnostics: Option<ActiveDiagnosticGroup>,
|
||||
soft_wrap_mode_override: Option<language_settings::SoftWrap>,
|
||||
project: Option<Model<Project>>,
|
||||
semantics_provider: Option<Rc<dyn SemanticsProvider>>,
|
||||
completion_provider: Option<Box<dyn CompletionProvider>>,
|
||||
collaboration_hub: Option<Box<dyn CollaborationHub>>,
|
||||
blink_manager: Model<BlinkManager>,
|
||||
@@ -886,12 +884,12 @@ enum ContextMenu {
|
||||
impl ContextMenu {
|
||||
fn select_first(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_first(provider, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_first(project, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_first(cx),
|
||||
}
|
||||
true
|
||||
@@ -902,12 +900,12 @@ impl ContextMenu {
|
||||
|
||||
fn select_prev(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_prev(provider, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_prev(project, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_prev(cx),
|
||||
}
|
||||
true
|
||||
@@ -918,12 +916,12 @@ impl ContextMenu {
|
||||
|
||||
fn select_next(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_next(provider, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_next(project, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_next(cx),
|
||||
}
|
||||
true
|
||||
@@ -934,12 +932,12 @@ impl ContextMenu {
|
||||
|
||||
fn select_last(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
if self.visible() {
|
||||
match self {
|
||||
ContextMenu::Completions(menu) => menu.select_last(provider, cx),
|
||||
ContextMenu::Completions(menu) => menu.select_last(project, cx),
|
||||
ContextMenu::CodeActions(menu) => menu.select_last(cx),
|
||||
}
|
||||
true
|
||||
@@ -993,55 +991,39 @@ struct CompletionsMenu {
|
||||
}
|
||||
|
||||
impl CompletionsMenu {
|
||||
fn select_first(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
fn select_first(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
self.selected_item = 0;
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn select_prev(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
fn select_prev(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
if self.selected_item > 0 {
|
||||
self.selected_item -= 1;
|
||||
} else {
|
||||
self.selected_item = self.matches.len() - 1;
|
||||
}
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn select_next(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
fn select_next(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
if self.selected_item + 1 < self.matches.len() {
|
||||
self.selected_item += 1;
|
||||
} else {
|
||||
self.selected_item = 0;
|
||||
}
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn select_last(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
fn select_last(&mut self, project: Option<&Model<Project>>, cx: &mut ViewContext<Editor>) {
|
||||
self.selected_item = self.matches.len() - 1;
|
||||
self.scroll_handle.scroll_to_item(self.selected_item);
|
||||
self.attempt_resolve_selected_completion_documentation(provider, cx);
|
||||
self.attempt_resolve_selected_completion_documentation(project, cx);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -1077,7 +1059,7 @@ impl CompletionsMenu {
|
||||
|
||||
fn attempt_resolve_selected_completion_documentation(
|
||||
&mut self,
|
||||
provider: Option<&dyn CompletionProvider>,
|
||||
project: Option<&Model<Project>>,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
let settings = EditorSettings::get_global(cx);
|
||||
@@ -1086,16 +1068,18 @@ impl CompletionsMenu {
|
||||
}
|
||||
|
||||
let completion_index = self.matches[self.selected_item].candidate_id;
|
||||
let Some(provider) = provider else {
|
||||
let Some(project) = project else {
|
||||
return;
|
||||
};
|
||||
|
||||
let resolve_task = provider.resolve_completions(
|
||||
self.buffer.clone(),
|
||||
vec![completion_index],
|
||||
self.completions.clone(),
|
||||
cx,
|
||||
);
|
||||
let resolve_task = project.update(cx, |project, cx| {
|
||||
project.resolve_completions(
|
||||
self.buffer.clone(),
|
||||
vec![completion_index],
|
||||
self.completions.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let delay_ms =
|
||||
EditorSettings::get_global(cx).completion_documentation_secondary_query_debounce;
|
||||
@@ -1687,7 +1671,7 @@ pub(crate) struct NavigationData {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum GotoDefinitionKind {
|
||||
enum GotoDefinitionKind {
|
||||
Symbol,
|
||||
Declaration,
|
||||
Type,
|
||||
@@ -1895,17 +1879,10 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
}));
|
||||
if let Some(task_inventory) = project
|
||||
.read(cx)
|
||||
.task_store()
|
||||
.read(cx)
|
||||
.task_inventory()
|
||||
.cloned()
|
||||
{
|
||||
project_subscriptions.push(cx.observe(&task_inventory, |editor, _, cx| {
|
||||
editor.tasks_update_task = Some(editor.refresh_runnables(cx));
|
||||
}));
|
||||
}
|
||||
let task_inventory = project.read(cx).task_inventory().clone();
|
||||
project_subscriptions.push(cx.observe(&task_inventory, |editor, _, cx| {
|
||||
editor.tasks_update_task = Some(editor.refresh_runnables(cx));
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1953,7 +1930,6 @@ impl Editor {
|
||||
active_diagnostics: None,
|
||||
soft_wrap_mode_override,
|
||||
completion_provider: project.clone().map(|project| Box::new(project) as _),
|
||||
semantics_provider: project.clone().map(|project| Rc::new(project) as _),
|
||||
collaboration_hub: project.clone().map(|project| Box::new(project) as _),
|
||||
project,
|
||||
blink_manager: blink_manager.clone(),
|
||||
@@ -2322,16 +2298,8 @@ impl Editor {
|
||||
self.custom_context_menu = Some(Box::new(f))
|
||||
}
|
||||
|
||||
pub fn set_completion_provider(&mut self, provider: Option<Box<dyn CompletionProvider>>) {
|
||||
self.completion_provider = provider;
|
||||
}
|
||||
|
||||
pub fn semantics_provider(&self) -> Option<Rc<dyn SemanticsProvider>> {
|
||||
self.semantics_provider.clone()
|
||||
}
|
||||
|
||||
pub fn set_semantics_provider(&mut self, provider: Option<Rc<dyn SemanticsProvider>>) {
|
||||
self.semantics_provider = provider;
|
||||
pub fn set_completion_provider(&mut self, provider: Box<dyn CompletionProvider>) {
|
||||
self.completion_provider = Some(provider);
|
||||
}
|
||||
|
||||
pub fn set_inline_completion_provider<T>(
|
||||
@@ -4066,7 +4034,7 @@ impl Editor {
|
||||
}
|
||||
|
||||
fn refresh_inlay_hints(&mut self, reason: InlayHintRefreshReason, cx: &mut ViewContext<Self>) {
|
||||
if self.semantics_provider.is_none() || self.mode != EditorMode::Full {
|
||||
if self.project.is_none() || self.mode != EditorMode::Full {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4720,13 +4688,11 @@ impl Editor {
|
||||
);
|
||||
}
|
||||
project.update(cx, |project, cx| {
|
||||
project.task_store().update(cx, |task_store, cx| {
|
||||
task_store.task_context_for_location(
|
||||
captured_task_variables,
|
||||
location,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
project.task_context_for_location(
|
||||
captured_task_variables,
|
||||
location,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
@@ -4938,11 +4904,6 @@ impl Editor {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn clear_code_action_providers(&mut self) {
|
||||
self.code_action_providers.clear();
|
||||
self.available_code_actions.take();
|
||||
}
|
||||
|
||||
pub fn push_code_action_provider(
|
||||
&mut self,
|
||||
provider: Arc<dyn CodeActionProvider>,
|
||||
@@ -5030,7 +4991,7 @@ impl Editor {
|
||||
return None;
|
||||
}
|
||||
|
||||
let provider = self.semantics_provider.clone()?;
|
||||
let project = self.project.clone()?;
|
||||
let buffer = self.buffer.read(cx);
|
||||
let newest_selection = self.selections.newest_anchor().clone();
|
||||
let cursor_position = newest_selection.head();
|
||||
@@ -5046,12 +5007,11 @@ impl Editor {
|
||||
.timer(DOCUMENT_HIGHLIGHTS_DEBOUNCE_TIMEOUT)
|
||||
.await;
|
||||
|
||||
let highlights = if let Some(highlights) = cx
|
||||
.update(|cx| {
|
||||
provider.document_highlights(&cursor_buffer, cursor_buffer_position, cx)
|
||||
let highlights = if let Some(highlights) = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.document_highlights(&cursor_buffer, cursor_buffer_position, cx)
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
.log_err()
|
||||
{
|
||||
highlights.await.log_err()
|
||||
} else {
|
||||
@@ -6241,13 +6201,6 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reload_file(&mut self, _: &ReloadFile, cx: &mut ViewContext<Self>) {
|
||||
let Some(project) = self.project.clone() else {
|
||||
return;
|
||||
};
|
||||
self.reload(project, cx).detach_and_notify_err(cx);
|
||||
}
|
||||
|
||||
pub fn revert_selected_hunks(&mut self, _: &RevertSelectedHunks, cx: &mut ViewContext<Self>) {
|
||||
let revert_changes = self.gather_revert_changes(&self.selections.disjoint_anchors(), cx);
|
||||
if !revert_changes.is_empty() {
|
||||
@@ -6260,22 +6213,14 @@ impl Editor {
|
||||
fn apply_selected_diff_hunks(&mut self, _: &ApplyDiffHunk, cx: &mut ViewContext<Self>) {
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let hunks = hunks_for_selections(&snapshot, &self.selections.disjoint_anchors());
|
||||
let mut ranges_by_buffer = HashMap::default();
|
||||
self.transact(cx, |editor, cx| {
|
||||
for hunk in hunks {
|
||||
if let Some(buffer) = editor.buffer.read(cx).buffer(hunk.buffer_id) {
|
||||
ranges_by_buffer
|
||||
.entry(buffer.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(hunk.buffer_range.to_offset(buffer.read(cx)));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.merge_into_base(Some(hunk.buffer_range.to_offset(buffer)), cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (buffer, ranges) in ranges_by_buffer {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.merge_into_base(ranges, cx);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -7480,7 +7425,7 @@ impl Editor {
|
||||
.context_menu
|
||||
.write()
|
||||
.as_mut()
|
||||
.map(|menu| menu.select_first(self.completion_provider.as_deref(), cx))
|
||||
.map(|menu| menu.select_first(self.project.as_ref(), cx))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
@@ -7589,7 +7534,7 @@ impl Editor {
|
||||
.context_menu
|
||||
.write()
|
||||
.as_mut()
|
||||
.map(|menu| menu.select_last(self.completion_provider.as_deref(), cx))
|
||||
.map(|menu| menu.select_last(self.project.as_ref(), cx))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
@@ -7641,25 +7586,25 @@ impl Editor {
|
||||
|
||||
pub fn context_menu_first(&mut self, _: &ContextMenuFirst, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_first(self.completion_provider.as_deref(), cx);
|
||||
context_menu.select_first(self.project.as_ref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_prev(&mut self, _: &ContextMenuPrev, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_prev(self.completion_provider.as_deref(), cx);
|
||||
context_menu.select_prev(self.project.as_ref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_next(&mut self, _: &ContextMenuNext, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_next(self.completion_provider.as_deref(), cx);
|
||||
context_menu.select_next(self.project.as_ref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_last(&mut self, _: &ContextMenuLast, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.write().as_mut() {
|
||||
context_menu.select_last(self.completion_provider.as_deref(), cx);
|
||||
context_menu.select_last(self.project.as_ref(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9152,29 +9097,23 @@ impl Editor {
|
||||
.map(|file| (file.worktree_id(cx), file.clone()))
|
||||
.unzip();
|
||||
|
||||
(
|
||||
project.task_store().read(cx).task_inventory().cloned(),
|
||||
worktree_id,
|
||||
file,
|
||||
)
|
||||
(project.task_inventory().clone(), worktree_id, file)
|
||||
});
|
||||
|
||||
let inventory = inventory.read(cx);
|
||||
let tags = mem::take(&mut runnable.tags);
|
||||
let mut tags: Vec<_> = tags
|
||||
.into_iter()
|
||||
.flat_map(|tag| {
|
||||
let tag = tag.0.clone();
|
||||
inventory
|
||||
.as_ref()
|
||||
.list_tasks(
|
||||
file.clone(),
|
||||
Some(runnable.language.clone()),
|
||||
worktree_id,
|
||||
cx,
|
||||
)
|
||||
.into_iter()
|
||||
.flat_map(|inventory| {
|
||||
inventory.read(cx).list_tasks(
|
||||
file.clone(),
|
||||
Some(runnable.language.clone()),
|
||||
worktree_id,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.filter(move |(_, template)| {
|
||||
template.tags.iter().any(|source_tag| source_tag == &tag)
|
||||
})
|
||||
@@ -9632,7 +9571,7 @@ impl Editor {
|
||||
split: bool,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Task<Result<Navigated>> {
|
||||
let Some(provider) = self.semantics_provider.clone() else {
|
||||
let Some(workspace) = self.workspace() else {
|
||||
return Task::ready(Ok(Navigated::No));
|
||||
};
|
||||
let buffer = self.buffer.read(cx);
|
||||
@@ -9643,9 +9582,13 @@ impl Editor {
|
||||
return Task::ready(Ok(Navigated::No));
|
||||
};
|
||||
|
||||
let Some(definitions) = provider.definitions(&buffer, head, kind, cx) else {
|
||||
return Task::ready(Ok(Navigated::No));
|
||||
};
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let definitions = project.update(cx, |project, cx| match kind {
|
||||
GotoDefinitionKind::Symbol => project.definition(&buffer, head, cx),
|
||||
GotoDefinitionKind::Declaration => project.declaration(&buffer, head, cx),
|
||||
GotoDefinitionKind::Type => project.type_definition(&buffer, head, cx),
|
||||
GotoDefinitionKind::Implementation => project.implementation(&buffer, head, cx),
|
||||
});
|
||||
|
||||
cx.spawn(|editor, mut cx| async move {
|
||||
let definitions = definitions.await?;
|
||||
@@ -9702,7 +9645,9 @@ impl Editor {
|
||||
return;
|
||||
};
|
||||
|
||||
let project = self.project.clone();
|
||||
let Some(project) = self.project.clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let result = find_file(&buffer, project, buffer_position, &mut cx).await;
|
||||
@@ -10104,7 +10049,7 @@ impl Editor {
|
||||
pub fn rename(&mut self, _: &Rename, cx: &mut ViewContext<Self>) -> Option<Task<Result<()>>> {
|
||||
use language::ToOffset as _;
|
||||
|
||||
let provider = self.semantics_provider.clone()?;
|
||||
let project = self.project.clone()?;
|
||||
let selection = self.selections.newest_anchor().clone();
|
||||
let (cursor_buffer, cursor_buffer_position) = self
|
||||
.buffer
|
||||
@@ -10121,9 +10066,9 @@ impl Editor {
|
||||
let snapshot = cursor_buffer.read(cx).snapshot();
|
||||
let cursor_buffer_offset = cursor_buffer_position.to_offset(&snapshot);
|
||||
let cursor_buffer_offset_end = cursor_buffer_position_end.to_offset(&snapshot);
|
||||
let prepare_rename = provider
|
||||
.range_for_rename(&cursor_buffer, cursor_buffer_position, cx)
|
||||
.unwrap_or_else(|| Task::ready(Ok(None)));
|
||||
let prepare_rename = project.update(cx, |project, cx| {
|
||||
project.prepare_rename(cursor_buffer.clone(), cursor_buffer_offset, cx)
|
||||
});
|
||||
drop(snapshot);
|
||||
|
||||
Some(cx.spawn(|this, mut cx| async move {
|
||||
@@ -10294,28 +10239,32 @@ impl Editor {
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
let rename = self.take_rename(false, cx)?;
|
||||
let workspace = self.workspace()?.downgrade();
|
||||
let (buffer, start) = self
|
||||
let workspace = self.workspace()?;
|
||||
let (start_buffer, start) = self
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(rename.range.start, cx)?;
|
||||
let (end_buffer, _) = self
|
||||
let (end_buffer, end) = self
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(rename.range.end, cx)?;
|
||||
if buffer != end_buffer {
|
||||
if start_buffer != end_buffer {
|
||||
return None;
|
||||
}
|
||||
|
||||
let buffer = start_buffer;
|
||||
let range = start..end;
|
||||
let old_name = rename.old_name;
|
||||
let new_name = rename.editor.read(cx).text(cx);
|
||||
|
||||
let rename = self.semantics_provider.as_ref()?.perform_rename(
|
||||
&buffer,
|
||||
start,
|
||||
new_name.clone(),
|
||||
cx,
|
||||
)?;
|
||||
let rename = workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.clone()
|
||||
.update(cx, |project, cx| {
|
||||
project.perform_rename(buffer.clone(), range.start, new_name.clone(), true, cx)
|
||||
});
|
||||
let workspace = workspace.downgrade();
|
||||
|
||||
Some(cx.spawn(|editor, mut cx| async move {
|
||||
let project_transaction = rename.await?;
|
||||
@@ -12189,14 +12138,9 @@ impl Editor {
|
||||
}
|
||||
|
||||
let Some(project) = &self.project else { return };
|
||||
let (telemetry, is_via_ssh) = {
|
||||
let project = project.read(cx);
|
||||
let telemetry = project.client().telemetry().clone();
|
||||
let is_via_ssh = project.is_via_ssh();
|
||||
(telemetry, is_via_ssh)
|
||||
};
|
||||
let telemetry = project.read(cx).client().telemetry().clone();
|
||||
refresh_linked_ranges(self, cx);
|
||||
telemetry.log_edit_event("editor", is_via_ssh);
|
||||
telemetry.log_edit_event("editor");
|
||||
}
|
||||
multi_buffer::Event::ExcerptsAdded {
|
||||
buffer,
|
||||
@@ -12370,22 +12314,14 @@ impl Editor {
|
||||
|
||||
let mut new_selections_by_buffer = HashMap::default();
|
||||
for selection in self.selections.all::<usize>(cx) {
|
||||
for (mut buffer_handle, mut range, _) in
|
||||
buffer.range_to_buffer_ranges(selection.range(), cx)
|
||||
for (buffer, mut range, _) in
|
||||
buffer.range_to_buffer_ranges(selection.start..selection.end, cx)
|
||||
{
|
||||
// When editing branch buffers, jump to the corresponding location
|
||||
// in their base buffer.
|
||||
let buffer = buffer_handle.read(cx);
|
||||
if let Some(base_buffer) = buffer.diff_base_buffer() {
|
||||
range = buffer.range_to_version(range, &base_buffer.read(cx).version());
|
||||
buffer_handle = base_buffer;
|
||||
}
|
||||
|
||||
if selection.reversed {
|
||||
mem::swap(&mut range.start, &mut range.end);
|
||||
}
|
||||
new_selections_by_buffer
|
||||
.entry(buffer_handle)
|
||||
.entry(buffer)
|
||||
.or_insert(Vec::new())
|
||||
.push(range)
|
||||
}
|
||||
@@ -12545,15 +12481,13 @@ impl Editor {
|
||||
.settings_at(0, cx)
|
||||
.show_inline_completions;
|
||||
|
||||
let project = project.read(cx);
|
||||
let telemetry = project.client().telemetry().clone();
|
||||
let telemetry = project.read(cx).client().telemetry().clone();
|
||||
telemetry.report_editor_event(
|
||||
file_extension,
|
||||
vim_mode,
|
||||
operation,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
project.is_via_ssh(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -12670,13 +12604,24 @@ impl Editor {
|
||||
}
|
||||
|
||||
pub fn supports_inlay_hints(&self, cx: &AppContext) -> bool {
|
||||
let Some(provider) = self.semantics_provider.as_ref() else {
|
||||
let Some(project) = self.project.as_ref() else {
|
||||
return false;
|
||||
};
|
||||
let project = project.read(cx);
|
||||
|
||||
let mut supports = false;
|
||||
self.buffer().read(cx).for_each_buffer(|buffer| {
|
||||
supports |= provider.supports_inlay_hints(buffer, cx);
|
||||
if !supports {
|
||||
supports = project
|
||||
.language_servers_for_buffer(buffer.read(cx), cx)
|
||||
.any(
|
||||
|(_, server)| match server.capabilities().inlay_hint_provider {
|
||||
Some(lsp::OneOf::Left(enabled)) => enabled,
|
||||
Some(lsp::OneOf::Right(_)) => true,
|
||||
None => false,
|
||||
},
|
||||
)
|
||||
}
|
||||
});
|
||||
supports
|
||||
}
|
||||
@@ -12942,62 +12887,6 @@ impl CollaborationHub for Model<Project> {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SemanticsProvider {
|
||||
fn hover(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Vec<project::Hover>>>;
|
||||
|
||||
fn inlay_hints(
|
||||
&self,
|
||||
buffer_handle: Model<Buffer>,
|
||||
range: Range<text::Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<Vec<InlayHint>>>>;
|
||||
|
||||
fn resolve_inlay_hint(
|
||||
&self,
|
||||
hint: InlayHint,
|
||||
buffer_handle: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<InlayHint>>>;
|
||||
|
||||
fn supports_inlay_hints(&self, buffer: &Model<Buffer>, cx: &AppContext) -> bool;
|
||||
|
||||
fn document_highlights(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<DocumentHighlight>>>>;
|
||||
|
||||
fn definitions(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
kind: GotoDefinitionKind,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<LocationLink>>>>;
|
||||
|
||||
fn range_for_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Option<Range<text::Anchor>>>>>;
|
||||
|
||||
fn perform_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
new_name: String,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<ProjectTransaction>>>;
|
||||
}
|
||||
|
||||
pub trait CompletionProvider {
|
||||
fn completions(
|
||||
&self,
|
||||
@@ -13094,11 +12983,18 @@ fn snippet_completions(
|
||||
return vec![];
|
||||
}
|
||||
let snapshot = buffer.read(cx).text_snapshot();
|
||||
let chars = snapshot.reversed_chars_for_range(text::Anchor::MIN..buffer_position);
|
||||
let chunks = snapshot.reversed_chunks_in_range(text::Anchor::MIN..buffer_position);
|
||||
|
||||
let mut lines = chunks.lines();
|
||||
let Some(line_at) = lines.next().filter(|line| !line.is_empty()) else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let scope = language.map(|language| language.default_scope());
|
||||
let classifier = CharClassifier::new(scope).for_completion(true);
|
||||
let mut last_word = chars
|
||||
let mut last_word = line_at
|
||||
.chars()
|
||||
.rev()
|
||||
.take_while(|c| classifier.is_word(*c))
|
||||
.collect::<String>();
|
||||
last_word = last_word.chars().rev().collect();
|
||||
@@ -13242,102 +13138,6 @@ impl CompletionProvider for Model<Project> {
|
||||
}
|
||||
}
|
||||
|
||||
impl SemanticsProvider for Model<Project> {
|
||||
fn hover(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Vec<project::Hover>>> {
|
||||
Some(self.update(cx, |project, cx| project.hover(buffer, position, cx)))
|
||||
}
|
||||
|
||||
fn document_highlights(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<DocumentHighlight>>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.document_highlights(buffer, position, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn definitions(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
kind: GotoDefinitionKind,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Vec<LocationLink>>>> {
|
||||
Some(self.update(cx, |project, cx| match kind {
|
||||
GotoDefinitionKind::Symbol => project.definition(&buffer, position, cx),
|
||||
GotoDefinitionKind::Declaration => project.declaration(&buffer, position, cx),
|
||||
GotoDefinitionKind::Type => project.type_definition(&buffer, position, cx),
|
||||
GotoDefinitionKind::Implementation => project.implementation(&buffer, position, cx),
|
||||
}))
|
||||
}
|
||||
|
||||
fn supports_inlay_hints(&self, buffer: &Model<Buffer>, cx: &AppContext) -> bool {
|
||||
// TODO: make this work for remote projects
|
||||
self.read(cx)
|
||||
.language_servers_for_buffer(buffer.read(cx), cx)
|
||||
.any(
|
||||
|(_, server)| match server.capabilities().inlay_hint_provider {
|
||||
Some(lsp::OneOf::Left(enabled)) => enabled,
|
||||
Some(lsp::OneOf::Right(_)) => true,
|
||||
None => false,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
fn inlay_hints(
|
||||
&self,
|
||||
buffer_handle: Model<Buffer>,
|
||||
range: Range<text::Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<Vec<InlayHint>>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.inlay_hints(buffer_handle, range, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn resolve_inlay_hint(
|
||||
&self,
|
||||
hint: InlayHint,
|
||||
buffer_handle: Model<Buffer>,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<anyhow::Result<InlayHint>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.resolve_inlay_hint(hint, buffer_handle, server_id, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn range_for_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<Option<Range<text::Anchor>>>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.prepare_rename(buffer.clone(), position, cx)
|
||||
}))
|
||||
}
|
||||
|
||||
fn perform_rename(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: text::Anchor,
|
||||
new_name: String,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Task<Result<ProjectTransaction>>> {
|
||||
Some(self.update(cx, |project, cx| {
|
||||
project.perform_rename(buffer.clone(), position, new_name, cx)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn inlay_hint_settings(
|
||||
location: Anchor,
|
||||
snapshot: &MultiBufferSnapshot,
|
||||
@@ -13631,7 +13431,6 @@ pub enum EditorEvent {
|
||||
TransactionBegun {
|
||||
transaction_id: clock::Lamport,
|
||||
},
|
||||
Reloaded,
|
||||
CursorShapeChanged,
|
||||
}
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ pub struct EditorSettingsContent {
|
||||
/// Default: true
|
||||
pub cursor_blink: Option<bool>,
|
||||
/// Cursor shape for the default editor.
|
||||
/// Can be "bar", "block", "underline", or "hollow".
|
||||
/// Can be "bar", "block", "underscore", or "hollow".
|
||||
///
|
||||
/// Default: None
|
||||
pub cursor_shape: Option<CursorShape>,
|
||||
|
||||
@@ -11715,60 +11715,6 @@ async fn test_toggle_diff_expand_in_multi_buffer(cx: &mut gpui::TestAppContext)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_expand_diff_hunk_at_excerpt_boundary(cx: &mut gpui::TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
let base = "aaa\nbbb\nccc\nddd\neee\nfff\nggg\n";
|
||||
let text = "aaa\nBBB\nBB2\nccc\nDDD\nEEE\nfff\nggg\n";
|
||||
|
||||
let buffer = cx.new_model(|cx| {
|
||||
let mut buffer = Buffer::local(text.to_string(), cx);
|
||||
buffer.set_diff_base(Some(base.into()), cx);
|
||||
buffer
|
||||
});
|
||||
|
||||
let multi_buffer = cx.new_model(|cx| {
|
||||
let mut multibuffer = MultiBuffer::new(ReadWrite);
|
||||
multibuffer.push_excerpts(
|
||||
buffer.clone(),
|
||||
[
|
||||
ExcerptRange {
|
||||
context: Point::new(0, 0)..Point::new(2, 0),
|
||||
primary: None,
|
||||
},
|
||||
ExcerptRange {
|
||||
context: Point::new(5, 0)..Point::new(7, 0),
|
||||
primary: None,
|
||||
},
|
||||
],
|
||||
cx,
|
||||
);
|
||||
multibuffer
|
||||
});
|
||||
|
||||
let editor = cx.add_window(|cx| Editor::new(EditorMode::Full, multi_buffer, None, true, cx));
|
||||
let mut cx = EditorTestContext::for_editor(editor, cx).await;
|
||||
cx.run_until_parked();
|
||||
|
||||
cx.update_editor(|editor, cx| editor.expand_all_hunk_diffs(&Default::default(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
cx.assert_diff_hunks(
|
||||
"
|
||||
aaa
|
||||
- bbb
|
||||
+ BBB
|
||||
|
||||
- ddd
|
||||
- eee
|
||||
+ EEE
|
||||
fff
|
||||
"
|
||||
.unindent(),
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edits_around_expanded_insertion_hunks(
|
||||
executor: BackgroundExecutor,
|
||||
|
||||
@@ -64,7 +64,7 @@ use std::{
|
||||
sync::Arc,
|
||||
};
|
||||
use sum_tree::Bias;
|
||||
use theme::{ActiveTheme, Appearance, PlayerColor};
|
||||
use theme::{ActiveTheme, PlayerColor};
|
||||
use ui::prelude::*;
|
||||
use ui::{h_flex, ButtonLike, ButtonStyle, ContextMenu, Tooltip};
|
||||
use util::RangeExt;
|
||||
@@ -437,8 +437,7 @@ impl EditorElement {
|
||||
register_action(view, cx, Editor::revert_file);
|
||||
register_action(view, cx, Editor::revert_selected_hunks);
|
||||
register_action(view, cx, Editor::apply_selected_diff_hunks);
|
||||
register_action(view, cx, Editor::open_active_item_in_terminal);
|
||||
register_action(view, cx, Editor::reload_file)
|
||||
register_action(view, cx, Editor::open_active_item_in_terminal)
|
||||
}
|
||||
|
||||
fn register_key_listeners(&self, cx: &mut WindowContext, layout: &EditorLayout) {
|
||||
@@ -1016,20 +1015,8 @@ impl EditorElement {
|
||||
block_width = em_width;
|
||||
}
|
||||
let block_text = if let CursorShape::Block = selection.cursor_shape {
|
||||
snapshot
|
||||
.display_chars_at(cursor_position)
|
||||
.next()
|
||||
.or_else(|| {
|
||||
if cursor_column == 0 {
|
||||
snapshot
|
||||
.placeholder_text()
|
||||
.and_then(|s| s.chars().next())
|
||||
.map(|c| (c, cursor_position))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.and_then(|(character, _)| {
|
||||
snapshot.display_chars_at(cursor_position).next().and_then(
|
||||
|(character, _)| {
|
||||
let text = if character == '\n' {
|
||||
SharedString::from(" ")
|
||||
} else {
|
||||
@@ -1044,22 +1031,6 @@ impl EditorElement {
|
||||
})
|
||||
.unwrap_or(self.style.text.font());
|
||||
|
||||
// Invert the text color for the block cursor. Ensure that the text
|
||||
// color is opaque enough to be visible against the background color.
|
||||
//
|
||||
// 0.75 is an arbitrary threshold to determine if the background color is
|
||||
// opaque enough to use as a text color.
|
||||
//
|
||||
// TODO: In the future we should ensure themes have a `text_inverse` color.
|
||||
let color = if cx.theme().colors().editor_background.a < 0.75 {
|
||||
match cx.theme().appearance {
|
||||
Appearance::Dark => Hsla::black(),
|
||||
Appearance::Light => Hsla::white(),
|
||||
}
|
||||
} else {
|
||||
cx.theme().colors().editor_background
|
||||
};
|
||||
|
||||
cx.text_system()
|
||||
.shape_line(
|
||||
text,
|
||||
@@ -1067,14 +1038,15 @@ impl EditorElement {
|
||||
&[TextRun {
|
||||
len,
|
||||
font,
|
||||
color,
|
||||
color: self.style.background,
|
||||
background_color: None,
|
||||
strikethrough: None,
|
||||
underline: None,
|
||||
}],
|
||||
)
|
||||
.log_err()
|
||||
})
|
||||
},
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -6088,7 +6060,7 @@ impl CursorLayout {
|
||||
origin: self.origin + origin,
|
||||
size: size(self.block_width, self.line_height),
|
||||
},
|
||||
CursorShape::Underline => Bounds {
|
||||
CursorShape::Underscore => Bounds {
|
||||
origin: self.origin
|
||||
+ origin
|
||||
+ gpui::Point::new(Pixels::ZERO, self.line_height - px(2.0)),
|
||||
|
||||
@@ -403,10 +403,7 @@ impl GitBlame {
|
||||
if this.user_triggered {
|
||||
log::error!("failed to get git blame data: {error:?}");
|
||||
let notification = format!("{:#}", error).trim().to_string();
|
||||
cx.emit(project::Event::Toast {
|
||||
notification_id: "git-blame".into(),
|
||||
message: notification,
|
||||
});
|
||||
cx.emit(project::Event::Notification(notification));
|
||||
} else {
|
||||
// If we weren't triggered by a user, we just log errors in the background, instead of sending
|
||||
// notifications.
|
||||
@@ -622,11 +619,9 @@ mod tests {
|
||||
let event = project.next_event(cx).await;
|
||||
assert_eq!(
|
||||
event,
|
||||
project::Event::Toast {
|
||||
notification_id: "git-blame".into(),
|
||||
message: "Failed to blame \"file.txt\": failed to get blame for \"file.txt\""
|
||||
.to_string()
|
||||
}
|
||||
project::Event::Notification(
|
||||
"Failed to blame \"file.txt\": failed to get blame for \"file.txt\"".to_string()
|
||||
)
|
||||
);
|
||||
|
||||
blame.update(cx, |blame, cx| {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use crate::{
|
||||
hover_popover::{self, InlayHover},
|
||||
scroll::ScrollAmount,
|
||||
Anchor, Editor, EditorSnapshot, FindAllReferences, GoToDefinition, GoToTypeDefinition,
|
||||
GotoDefinitionKind, InlayId, Navigated, PointForPosition, SelectPhase,
|
||||
Anchor, Editor, EditorSnapshot, FindAllReferences, GoToDefinition, GoToTypeDefinition, InlayId,
|
||||
Navigated, PointForPosition, SelectPhase,
|
||||
};
|
||||
use gpui::{px, AppContext, AsyncWindowContext, Model, Modifiers, Task, ViewContext};
|
||||
use language::{Bias, ToOffset};
|
||||
@@ -14,12 +14,12 @@ use project::{
|
||||
};
|
||||
use std::ops::Range;
|
||||
use theme::ActiveTheme as _;
|
||||
use util::{maybe, ResultExt, TryFutureExt as _};
|
||||
use util::{maybe, ResultExt, TryFutureExt};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HoveredLinkState {
|
||||
pub last_trigger_point: TriggerPoint,
|
||||
pub preferred_kind: GotoDefinitionKind,
|
||||
pub preferred_kind: LinkDefinitionKind,
|
||||
pub symbol_range: Option<RangeInEditor>,
|
||||
pub links: Vec<HoverLink>,
|
||||
pub task: Option<Task<Option<()>>>,
|
||||
@@ -428,6 +428,12 @@ pub fn update_inlay_link_and_hover_points(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum LinkDefinitionKind {
|
||||
Symbol,
|
||||
Type,
|
||||
}
|
||||
|
||||
pub fn show_link_definition(
|
||||
shift_held: bool,
|
||||
editor: &mut Editor,
|
||||
@@ -436,8 +442,8 @@ pub fn show_link_definition(
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) {
|
||||
let preferred_kind = match trigger_point {
|
||||
TriggerPoint::Text(_) if !shift_held => GotoDefinitionKind::Symbol,
|
||||
_ => GotoDefinitionKind::Type,
|
||||
TriggerPoint::Text(_) if !shift_held => LinkDefinitionKind::Symbol,
|
||||
_ => LinkDefinitionKind::Type,
|
||||
};
|
||||
|
||||
let (mut hovered_link_state, is_cached) =
|
||||
@@ -499,7 +505,6 @@ pub fn show_link_definition(
|
||||
editor.hide_hovered_link(cx)
|
||||
}
|
||||
let project = editor.project.clone();
|
||||
let provider = editor.semantics_provider.clone();
|
||||
|
||||
let snapshot = snapshot.buffer_snapshot.clone();
|
||||
hovered_link_state.task = Some(cx.spawn(|this, mut cx| {
|
||||
@@ -517,40 +522,54 @@ pub fn show_link_definition(
|
||||
(range, vec![HoverLink::Url(url)])
|
||||
})
|
||||
.ok()
|
||||
} else if let Some((filename_range, filename)) =
|
||||
find_file(&buffer, project.clone(), buffer_position, &mut cx).await
|
||||
{
|
||||
let range = maybe!({
|
||||
let start =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, filename_range.start)?;
|
||||
let end = snapshot.anchor_in_excerpt(excerpt_id, filename_range.end)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
});
|
||||
} else if let Some(project) = project {
|
||||
if let Some((filename_range, filename)) =
|
||||
find_file(&buffer, project.clone(), buffer_position, &mut cx).await
|
||||
{
|
||||
let range = maybe!({
|
||||
let start =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, filename_range.start)?;
|
||||
let end =
|
||||
snapshot.anchor_in_excerpt(excerpt_id, filename_range.end)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
});
|
||||
|
||||
Some((range, vec![HoverLink::File(filename)]))
|
||||
} else if let Some(provider) = provider {
|
||||
let task = cx.update(|cx| {
|
||||
provider.definitions(&buffer, buffer_position, preferred_kind, cx)
|
||||
})?;
|
||||
if let Some(task) = task {
|
||||
task.await.ok().map(|definition_result| {
|
||||
(
|
||||
definition_result.iter().find_map(|link| {
|
||||
link.origin.as_ref().and_then(|origin| {
|
||||
let start = snapshot.anchor_in_excerpt(
|
||||
excerpt_id,
|
||||
origin.range.start,
|
||||
)?;
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, origin.range.end)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
})
|
||||
}),
|
||||
definition_result.into_iter().map(HoverLink::Text).collect(),
|
||||
)
|
||||
})
|
||||
Some((range, vec![HoverLink::File(filename)]))
|
||||
} else {
|
||||
None
|
||||
// query the LSP for definition info
|
||||
project
|
||||
.update(&mut cx, |project, cx| match preferred_kind {
|
||||
LinkDefinitionKind::Symbol => {
|
||||
project.definition(&buffer, buffer_position, cx)
|
||||
}
|
||||
|
||||
LinkDefinitionKind::Type => {
|
||||
project.type_definition(&buffer, buffer_position, cx)
|
||||
}
|
||||
})?
|
||||
.await
|
||||
.ok()
|
||||
.map(|definition_result| {
|
||||
(
|
||||
definition_result.iter().find_map(|link| {
|
||||
link.origin.as_ref().and_then(|origin| {
|
||||
let start = snapshot.anchor_in_excerpt(
|
||||
excerpt_id,
|
||||
origin.range.start,
|
||||
)?;
|
||||
let end = snapshot.anchor_in_excerpt(
|
||||
excerpt_id,
|
||||
origin.range.end,
|
||||
)?;
|
||||
Some(RangeInEditor::Text(start..end))
|
||||
})
|
||||
}),
|
||||
definition_result
|
||||
.into_iter()
|
||||
.map(HoverLink::Text)
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
None
|
||||
@@ -689,11 +708,10 @@ pub(crate) fn find_url(
|
||||
|
||||
pub(crate) async fn find_file(
|
||||
buffer: &Model<language::Buffer>,
|
||||
project: Option<Model<Project>>,
|
||||
project: Model<Project>,
|
||||
position: text::Anchor,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Option<(Range<text::Anchor>, ResolvedPath)> {
|
||||
let project = project?;
|
||||
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()).ok()?;
|
||||
let scope = snapshot.language_scope_at(position);
|
||||
let (range, candidate_file_path) = surrounding_filename(snapshot, position)?;
|
||||
|
||||
@@ -195,22 +195,32 @@ fn show_hover(
|
||||
anchor: Anchor,
|
||||
ignore_timeout: bool,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> Option<()> {
|
||||
) {
|
||||
if editor.pending_rename.is_some() {
|
||||
return None;
|
||||
return;
|
||||
}
|
||||
|
||||
let snapshot = editor.snapshot(cx);
|
||||
|
||||
let (buffer, buffer_position) = editor
|
||||
.buffer
|
||||
.read(cx)
|
||||
.text_anchor_for_position(anchor, cx)?;
|
||||
let (buffer, buffer_position) =
|
||||
if let Some(output) = editor.buffer.read(cx).text_anchor_for_position(anchor, cx) {
|
||||
output
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
let (excerpt_id, _, _) = editor.buffer().read(cx).excerpt_containing(anchor, cx)?;
|
||||
let excerpt_id =
|
||||
if let Some((excerpt_id, _, _)) = editor.buffer().read(cx).excerpt_containing(anchor, cx) {
|
||||
excerpt_id
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
let language_registry = editor.project.as_ref()?.read(cx).languages().clone();
|
||||
let provider = editor.semantics_provider.clone()?;
|
||||
let project = if let Some(project) = editor.project.clone() {
|
||||
project
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
if !ignore_timeout {
|
||||
if same_info_hover(editor, &snapshot, anchor)
|
||||
@@ -218,7 +228,7 @@ fn show_hover(
|
||||
|| editor.hover_state.diagnostic_popover.is_some()
|
||||
{
|
||||
// Hover triggered from same location as last time. Don't show again.
|
||||
return None;
|
||||
return;
|
||||
} else {
|
||||
hide_hover(editor, cx);
|
||||
}
|
||||
@@ -230,7 +240,7 @@ fn show_hover(
|
||||
.cmp(&anchor, &snapshot.buffer_snapshot)
|
||||
.is_eq()
|
||||
{
|
||||
return None;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,7 +262,12 @@ fn show_hover(
|
||||
total_delay
|
||||
};
|
||||
|
||||
let hover_request = cx.update(|cx| provider.hover(&buffer, buffer_position, cx))?;
|
||||
// query the LSP for hover info
|
||||
let hover_request = cx.update(|cx| {
|
||||
project.update(cx, |project, cx| {
|
||||
project.hover(&buffer, buffer_position, cx)
|
||||
})
|
||||
})?;
|
||||
|
||||
if let Some(delay) = delay {
|
||||
delay.await;
|
||||
@@ -362,11 +377,8 @@ fn show_hover(
|
||||
this.hover_state.diagnostic_popover = diagnostic_popover;
|
||||
})?;
|
||||
|
||||
let hovers_response = if let Some(hover_request) = hover_request {
|
||||
hover_request.await
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let hovers_response = hover_request.await;
|
||||
let language_registry = project.update(&mut cx, |p, _| p.languages().clone())?;
|
||||
let snapshot = this.update(&mut cx, |this, cx| this.snapshot(cx))?;
|
||||
let mut hover_highlights = Vec::with_capacity(hovers_response.len());
|
||||
let mut info_popovers = Vec::with_capacity(hovers_response.len());
|
||||
@@ -439,7 +451,6 @@ fn show_hover(
|
||||
});
|
||||
|
||||
editor.hover_state.info_task = Some(task);
|
||||
None
|
||||
}
|
||||
|
||||
fn same_info_hover(editor: &Editor, snapshot: &EditorSnapshot, anchor: Anchor) -> bool {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user