Compare commits: agent-perf...ep-example
74 Commits
Commits (SHA1 only; author and date not shown):
8953b487ad, 196c488ed4, dfbbacec12, 9161a23513, 9a8ccb32ac, 5cfdfd32c6, defcc2f51b, 6ebe0edea0,
1a83c0f5e4, 27a6d54efe, a168d8f50a, e243a658a5, a93fd51f35, 0dcdc6d9a4, 7e09b59fa3, 1e28bf8279,
b6eec44a99, d83c985923, 74c4e25b8c, 2021f32947, 299ca2e8ac, c284f9086b, fc89e19098, f53b01d5a2,
bf1c8819d9, 3247264288, 6d947b7746, db221ca72d, 1d006a8cb0, aaab9f6960, 209cf0a48f, 260691c99c,
9e88f3f33c, 2cad6c8ef1, bc24ffe863, 1e4a970ae2, 3e656a0911, 57ea23d161, a50c5b2c10, f1b723973b,
a7ce677ac3, ed67f246cb, 93f29326c4, 85f4681299, 741c5d5010, f03987fb68, ca47822667, a34fe06bb1,
0ce484e66c, 251033f88f, 9f90c1a1b7, d43cc46288, fdb8e71b43, 6bc433ed43, 1281f4672c, ed705c0cbc,
8980333e23, acee48bfda, 71298e6949, 07ada58466, dd521a96fb, f9d9721b93, cff3ac6f93, 746b76488c,
397fcf6083, 9adb3e1daa, 1469d94683, 3b626c8ac1, 3dc0614dba, 045e154915, dc72e1c4ba, 0884305e43,
83449293b6, 213cb30445

.github/actionlint.yml (vendored, 1 change)

@@ -25,6 +25,7 @@ self-hosted-runner:
- namespace-profile-32x64-ubuntu-2204
# Namespace Ubuntu 24.04 (like ubuntu-latest)
- namespace-profile-2x4-ubuntu-2404
- namespace-profile-8x32-ubuntu-2404
# Namespace Limited Preview
- namespace-profile-8x16-ubuntu-2004-arm-m4
- namespace-profile-8x32-ubuntu-2004-arm-m4

.github/workflows/extension_bump.yml (vendored, 4 changes)

@@ -66,7 +66,7 @@ jobs:
if: |-
(github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') &&
(inputs.force-bump == 'true' || needs.check_bump_needed.outputs.needs_bump == 'true')
runs-on: namespace-profile-8x16-ubuntu-2204
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
- id: generate-token
name: extension_bump::generate_token
@@ -119,7 +119,7 @@ jobs:
needs:
- check_bump_needed
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions') && github.event_name == 'push' && github.ref == 'refs/heads/main' && needs.check_bump_needed.outputs.needs_bump == 'false'
runs-on: namespace-profile-8x16-ubuntu-2204
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
- id: generate-token
name: extension_bump::generate_token

.github/workflows/extension_release.yml (vendored, 2 changes)

@@ -13,7 +13,7 @@ on:
jobs:
create_release:
if: (github.repository_owner == 'zed-industries' || github.repository_owner == 'zed-extensions')
runs-on: namespace-profile-8x16-ubuntu-2204
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
- id: generate-token
name: extension_bump::generate_token

.github/workflows/extension_tests.yml (vendored, 4 changes)

@@ -51,7 +51,7 @@ jobs:
needs:
- orchestrate
if: needs.orchestrate.outputs.check_rust == 'true'
runs-on: namespace-profile-16x32-ubuntu-2204
runs-on: namespace-profile-4x8-ubuntu-2204
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
@@ -79,7 +79,7 @@ jobs:
needs:
- orchestrate
if: needs.orchestrate.outputs.check_extension == 'true'
runs-on: namespace-profile-2x4-ubuntu-2404
runs-on: namespace-profile-8x32-ubuntu-2404
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683

.github/workflows/extension_workflow_rollout.yml (vendored, new file, 106 lines)

@@ -0,0 +1,106 @@
# Generated from xtask::workflows::extension_workflow_rollout
# Rebuild with `cargo xtask workflows`.
name: extension_workflow_rollout
env:
CARGO_TERM_COLOR: always
on:
workflow_dispatch: {}
jobs:
fetch_extension_repos:
runs-on: namespace-profile-2x4-ubuntu-2404
steps:
- id: list-repos
name: extension_workflow_rollout::fetch_extension_repos::get_repositories
uses: actions/github-script@v7
with:
script: |
const repos = await github.paginate(github.rest.repos.listForOrg, {
org: 'zed-extensions',
type: 'public',
per_page: 100,
});

const filteredRepos = repos
.filter(repo => !repo.archived)
.filter(repo => repo.name !== 'workflows' && repo.name !== 'material-icon-theme')
.map(repo => repo.name);

console.log(`Found ${filteredRepos.length} extension repos`);
return filteredRepos;
result-encoding: json
outputs:
repos: ${{ steps.list-repos.outputs.result }}
timeout-minutes: 5
rollout_workflows_to_extension:
needs:
- fetch_extension_repos
if: needs.fetch_extension_repos.outputs.repos != '[]'
runs-on: namespace-profile-2x4-ubuntu-2404
strategy:
matrix:
repo: ${{ fromJson(needs.fetch_extension_repos.outputs.repos) }}
fail-fast: false
max-parallel: 5
steps:
- id: generate-token
name: extension_bump::generate_token
uses: actions/create-github-app-token@v2
with:
app-id: ${{ secrets.ZED_ZIPPY_APP_ID }}
private-key: ${{ secrets.ZED_ZIPPY_APP_PRIVATE_KEY }}
owner: zed-extensions
repositories: ${{ matrix.repo }}
permission-pull-requests: write
permission-contents: write
permission-workflows: write
- name: checkout_zed_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
clean: false
path: zed
- name: steps::checkout_repo_with_token
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
clean: false
token: ${{ steps.generate-token.outputs.token }}
repository: zed-extensions/${{ matrix.repo }}
path: extension
- name: extension_workflow_rollout::rollout_workflows_to_extension::copy_workflow_files
run: |
mkdir -p extension/.github/workflows
cp zed/extensions/workflows/shared/*.yml extension/.github/workflows/
shell: bash -euxo pipefail {0}
- id: short-sha
name: extension_workflow_rollout::rollout_workflows_to_extension::get_short_sha
run: |
echo "sha_short=$(git rev-parse --short HEAD)" >> "$GITHUB_OUTPUT"
shell: bash -euxo pipefail {0}
working-directory: zed
- id: create-pr
name: extension_workflow_rollout::rollout_workflows_to_extension::create_pull_request
uses: peter-evans/create-pull-request@v7
with:
path: extension
title: Update CI workflows to zed@${{ steps.short-sha.outputs.sha_short }}
body: |
This PR updates the CI workflow files from the main Zed repository
based on the commit zed-industries/zed@${{ github.sha }}
commit-message: Update CI workflows to zed@${{ steps.short-sha.outputs.sha_short }}
branch: update-workflows
committer: zed-zippy[bot] <234243425+zed-zippy[bot]@users.noreply.github.com>
author: zed-zippy[bot] <234243425+zed-zippy[bot]@users.noreply.github.com>
base: main
delete-branch: true
token: ${{ steps.generate-token.outputs.token }}
sign-commits: true
- name: extension_workflow_rollout::rollout_workflows_to_extension::enable_auto_merge
run: |
PR_NUMBER="${{ steps.create-pr.outputs.pull-request-number }}"
if [ -n "$PR_NUMBER" ]; then
cd extension
gh pr merge "$PR_NUMBER" --auto --squash
fi
shell: bash -euxo pipefail {0}
env:
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
timeout-minutes: 10
@@ -6,7 +6,7 @@ on:

jobs:
handle-good-first-issue:
if: github.event.label.name == 'good first issue' && github.repository_owner == 'zed-industries'
if: github.event.label.name == '.contrib/good first issue' && github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest

steps:
@@ -23,7 +23,6 @@ In particular we love PRs that are:

If you're looking for concrete ideas:

- [Curated board of issues](https://github.com/orgs/zed-industries/projects/69) suitable for everyone from first-time contributors to seasoned community champions.
- [Triaged bugs with confirmed steps to reproduce](https://github.com/zed-industries/zed/issues?q=is%3Aissue%20state%3Aopen%20type%3ABug%20label%3Astate%3Areproducible).
- [Area labels](https://github.com/zed-industries/zed/labels?q=area%3A*) to browse bugs in a specific part of the product you care about (after clicking on an area label, add type:Bug to the search).

Cargo.lock (generated, 19 changes)

@@ -268,6 +268,7 @@ dependencies = [
"client",
"collections",
"env_logger 0.11.8",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
@@ -5212,6 +5213,7 @@ dependencies = [
"anyhow",
"arrayvec",
"brotli",
"buffer_diff",
"client",
"clock",
"cloud_api_types",
@@ -5249,7 +5251,10 @@ dependencies = [
"strum 0.27.2",
"telemetry",
"telemetry_events",
"text",
"thiserror 2.0.17",
"time",
"toml 0.8.23",
"ui",
"util",
"uuid",
@@ -5354,8 +5359,10 @@ dependencies = [
"anyhow",
"buffer_diff",
"client",
"clock",
"cloud_llm_client",
"codestral",
"collections",
"command_palette_hooks",
"copilot",
"edit_prediction",
@@ -5364,18 +5371,20 @@ dependencies = [
"feature_flags",
"fs",
"futures 0.3.31",
"git",
"gpui",
"indoc",
"language",
"log",
"language_model",
"lsp",
"markdown",
"menu",
"multi_buffer",
"paths",
"pretty_assertions",
"project",
"regex",
"release_channel",
"semver",
"serde_json",
"settings",
"supermaven",
@@ -5388,6 +5397,7 @@ dependencies = [
"workspace",
"zed_actions",
"zeta_prompt",
"zlog",
]

[[package]]
@@ -8645,6 +8655,7 @@ dependencies = [
"extension",
"gpui",
"language",
"lsp",
"paths",
"project",
"schemars",
@@ -20955,7 +20966,7 @@ dependencies = [

[[package]]
name = "zed_glsl"
version = "0.1.0"
version = "0.2.0"
dependencies = [
"zed_extension_api 0.1.0",
]
@@ -20969,7 +20980,7 @@ dependencies = [

[[package]]
name = "zed_proto"
version = "0.3.0"
version = "0.3.1"
dependencies = [
"zed_extension_api 0.7.0",
]
@@ -241,6 +241,7 @@
|
||||
"ctrl-alt-l": "agent::OpenRulesLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "agent::ToggleModelSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
"ctrl-shift-j": "agent::ToggleNavigationMenu",
|
||||
"ctrl-alt-i": "agent::ToggleOptionsMenu",
|
||||
"ctrl-alt-shift-n": "agent::ToggleNewThreadMenu",
|
||||
@@ -253,7 +254,6 @@
|
||||
"ctrl-y": "agent::AllowOnce",
|
||||
"ctrl-alt-y": "agent::AllowAlways",
|
||||
"ctrl-alt-z": "agent::RejectOnce",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -285,38 +285,6 @@
|
||||
"ctrl-alt-t": "agent::NewThread",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "MessageEditor && !Picker > Editor && !use_modifier_to_send",
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"ctrl-enter": "agent::ChatWithFollow",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"ctrl-shift-v": "agent::PasteRaw",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "MessageEditor && !Picker > Editor && use_modifier_to_send",
|
||||
"bindings": {
|
||||
"ctrl-enter": "agent::Chat",
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"ctrl-shift-v": "agent::PasteRaw",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "EditMessageEditor > Editor",
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "menu::Confirm",
|
||||
"alt-enter": "editor::Newline",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AgentFeedbackMessageEditor > Editor",
|
||||
"bindings": {
|
||||
@@ -331,14 +299,25 @@
|
||||
"ctrl-enter": "menu::Confirm",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AcpThread > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-enter": "agent::ChatWithFollow",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-shift-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"ctrl-shift-v": "agent::PasteRaw",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AcpThread > Editor && !use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -346,11 +325,7 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-enter": "agent::Chat",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
"enter": "editor::Newline",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -817,7 +792,7 @@
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "PromptEditor",
|
||||
"context": "InlineAssistant",
|
||||
"bindings": {
|
||||
"ctrl-[": "agent::CyclePreviousInlineAssist",
|
||||
"ctrl-]": "agent::CycleNextInlineAssist",
|
||||
|
||||
@@ -282,6 +282,7 @@
|
||||
"cmd-alt-p": "agent::ManageProfiles",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
"cmd-shift-j": "agent::ToggleNavigationMenu",
|
||||
"cmd-alt-m": "agent::ToggleOptionsMenu",
|
||||
"cmd-alt-shift-n": "agent::ToggleNewThreadMenu",
|
||||
@@ -294,7 +295,6 @@
|
||||
"cmd-y": "agent::AllowOnce",
|
||||
"cmd-alt-y": "agent::AllowAlways",
|
||||
"cmd-alt-z": "agent::RejectOnce",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -326,41 +326,6 @@
|
||||
"cmd-alt-t": "agent::NewThread",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "MessageEditor && !Picker > Editor && !use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"cmd-enter": "agent::ChatWithFollow",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll",
|
||||
"cmd-shift-v": "agent::PasteRaw",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "MessageEditor && !Picker > Editor && use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-enter": "agent::Chat",
|
||||
"enter": "editor::Newline",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll",
|
||||
"cmd-shift-v": "agent::PasteRaw",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "EditMessageEditor > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "menu::Confirm",
|
||||
"alt-enter": "editor::Newline",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AgentFeedbackMessageEditor > Editor",
|
||||
"use_key_equivalents": true,
|
||||
@@ -382,16 +347,25 @@
|
||||
"cmd-enter": "menu::Confirm",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AcpThread > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll",
|
||||
"cmd-enter": "agent::ChatWithFollow",
|
||||
"cmd-shift-v": "agent::PasteRaw",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AcpThread > Editor && !use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -399,11 +373,7 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-enter": "agent::Chat",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
"enter": "editor::Newline",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -883,7 +853,7 @@
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "PromptEditor",
|
||||
"context": "InlineAssistant > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
|
||||
@@ -241,6 +241,7 @@
|
||||
"shift-alt-l": "agent::OpenRulesLibrary",
|
||||
"shift-alt-p": "agent::ManageProfiles",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
"shift-alt-/": "agent::ToggleModelSelector",
|
||||
"shift-alt-j": "agent::ToggleNavigationMenu",
|
||||
"shift-alt-i": "agent::ToggleOptionsMenu",
|
||||
@@ -254,7 +255,6 @@
|
||||
"shift-alt-a": "agent::AllowOnce",
|
||||
"ctrl-alt-y": "agent::AllowAlways",
|
||||
"shift-alt-z": "agent::RejectOnce",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -287,41 +287,6 @@
|
||||
"ctrl-alt-t": "agent::NewThread",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "MessageEditor && !Picker > Editor && !use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"ctrl-enter": "agent::ChatWithFollow",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-shift-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"ctrl-shift-v": "agent::PasteRaw",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "MessageEditor && !Picker > Editor && use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-enter": "agent::Chat",
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-shift-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"ctrl-shift-v": "agent::PasteRaw",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "EditMessageEditor > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel",
|
||||
"enter": "menu::Confirm",
|
||||
"alt-enter": "editor::Newline",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AgentFeedbackMessageEditor > Editor",
|
||||
"use_key_equivalents": true,
|
||||
@@ -337,16 +302,25 @@
|
||||
"ctrl-enter": "menu::Confirm",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AcpThread > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-enter": "agent::ChatWithFollow",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-shift-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"ctrl-shift-v": "agent::PasteRaw",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "AcpThread > Editor && !use_modifier_to_send",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"ctrl-shift-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -354,11 +328,7 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-enter": "agent::Chat",
|
||||
"ctrl-shift-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll",
|
||||
"shift-tab": "agent::CycleModeSelector",
|
||||
"alt-tab": "agent::CycleFavoriteModels",
|
||||
"enter": "editor::Newline",
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -826,7 +796,7 @@
|
||||
},
|
||||
},
|
||||
{
|
||||
"context": "PromptEditor",
|
||||
"context": "InlineAssistant",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-[": "agent::CyclePreviousInlineAssist",
|
||||
|
||||
@@ -24,7 +24,7 @@
},
},
{
"context": "InlineAssistEditor",
"context": "InlineAssistant > Editor",
"use_key_equivalents": true,
"bindings": {
"ctrl-shift-backspace": "editor::Cancel",
@@ -24,7 +24,7 @@
},
},
{
"context": "InlineAssistEditor",
"context": "InlineAssistant > Editor",
"use_key_equivalents": true,
"bindings": {
"cmd-shift-backspace": "editor::Cancel",
@@ -884,6 +884,7 @@ pub enum AcpThreadEvent {
Refusal,
AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
ModeUpdated(acp::SessionModeId),
ConfigOptionsUpdated(Vec<acp::SessionConfigOption>),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
@@ -1193,6 +1194,10 @@ impl AcpThread {
current_mode_id,
..
}) => cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id)),
acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
config_options,
..
}) => cx.emit(AcpThreadEvent::ConfigOptionsUpdated(config_options)),
_ => {}
}
Ok(())
@@ -86,6 +86,14 @@ pub trait AgentConnection {
None
}

fn session_config_options(
&self,
_session_id: &acp::SessionId,
_cx: &App,
) -> Option<Rc<dyn AgentSessionConfigOptions>> {
None
}

fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}

@@ -125,6 +133,26 @@ pub trait AgentSessionModes {
fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
}

pub trait AgentSessionConfigOptions {
/// Get all current config options with their state
fn config_options(&self) -> Vec<acp::SessionConfigOption>;

/// Set a config option value
/// Returns the full updated list of config options
fn set_config_option(
&self,
config_id: acp::SessionConfigId,
value: acp::SessionConfigValueId,
cx: &mut App,
) -> Task<Result<Vec<acp::SessionConfigOption>>>;

/// Whenever the config options are updated the receiver will be notified.
/// Optional for agents that don't update their config options dynamically.
fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
None
}
}
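To illustrate how these pieces fit together, here is a minimal caller-side sketch (not part of the diff). It assumes the traits above live in the `acp_thread` crate, that a `connection: Rc<dyn AgentConnection>` and session ID are already in hand, and that `config_id`/`value` are placeholder values; in real code the returned task would be awaited or detached rather than dropped.

```rust
use std::rc::Rc;

use acp_thread::{AgentConnection, AgentSessionConfigOptions};
use agent_client_protocol as acp;
use gpui::App;

// Hypothetical helper driving the new config-option API end to end.
fn toggle_config_option(
    connection: Rc<dyn AgentConnection>,
    session_id: &acp::SessionId,
    config_id: acp::SessionConfigId,
    value: acp::SessionConfigValueId,
    cx: &mut App,
) {
    if let Some(config) = connection.session_config_options(session_id, cx) {
        // Snapshot the current options and their state.
        let _current: Vec<acp::SessionConfigOption> = config.config_options();

        // Apply the new value; the task resolves to the full, updated option list.
        let _update = config.set_config_option(config_id, value, cx);

        // Agents that change options dynamically hand back a watch receiver;
        // a UI would listen on it to know when to re-render.
        let _maybe_rx = config.watch(cx);
    }
}
```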
#[derive(Debug)]
pub struct AuthRequired {
pub description: Option<String>,
@@ -202,12 +230,6 @@ pub trait AgentModelSelector: 'static {
fn should_render_footer(&self) -> bool {
false
}

/// Whether this selector supports the favorites feature.
/// Only the native agent uses the model ID format that maps to settings.
fn supports_favorites(&self) -> bool {
false
}
}

/// Icon for a model in the model selector.
@@ -4,22 +4,20 @@ use std::{
|
||||
fmt::Display,
|
||||
rc::{Rc, Weak},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use collections::HashMap;
|
||||
use gpui::{
|
||||
App, ClipboardItem, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment,
|
||||
ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list,
|
||||
prelude::*,
|
||||
App, Empty, Entity, EventEmitter, FocusHandle, Focusable, Global, ListAlignment, ListState,
|
||||
StyleRefinement, Subscription, Task, TextStyleRefinement, Window, actions, list, prelude::*,
|
||||
};
|
||||
use language::LanguageRegistry;
|
||||
use markdown::{CodeBlockRenderer, Markdown, MarkdownElement, MarkdownStyle};
|
||||
use project::Project;
|
||||
use settings::Settings;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Tooltip, WithScrollbar, prelude::*};
|
||||
use ui::{CopyButton, Tooltip, WithScrollbar, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{
|
||||
Item, ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
|
||||
@@ -544,15 +542,11 @@ impl Render for AcpTools {
|
||||
|
||||
pub struct AcpToolsToolbarItemView {
|
||||
acp_tools: Option<Entity<AcpTools>>,
|
||||
just_copied: bool,
|
||||
}
|
||||
|
||||
impl AcpToolsToolbarItemView {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
acp_tools: None,
|
||||
just_copied: false,
|
||||
}
|
||||
Self { acp_tools: None }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -572,37 +566,14 @@ impl Render for AcpToolsToolbarItemView {
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child({
|
||||
let acp_tools = acp_tools.clone();
|
||||
IconButton::new(
|
||||
"copy_all_messages",
|
||||
if self.just_copied {
|
||||
IconName::Check
|
||||
} else {
|
||||
IconName::Copy
|
||||
},
|
||||
)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text(if self.just_copied {
|
||||
"Copied!"
|
||||
} else {
|
||||
"Copy All Messages"
|
||||
}))
|
||||
.disabled(!has_messages)
|
||||
.on_click(cx.listener(move |this, _, _window, cx| {
|
||||
if let Some(content) = acp_tools.read(cx).serialize_observed_messages() {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(content));
|
||||
let message = acp_tools
|
||||
.read(cx)
|
||||
.serialize_observed_messages()
|
||||
.unwrap_or_default();
|
||||
|
||||
this.just_copied = true;
|
||||
cx.spawn(async move |this, cx| {
|
||||
cx.background_executor().timer(Duration::from_secs(2)).await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.just_copied = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}))
|
||||
CopyButton::new(message)
|
||||
.tooltip_label("Copy All Messages")
|
||||
.disabled(!has_messages)
|
||||
})
|
||||
.child(
|
||||
IconButton::new("clear_messages", IconName::Trash)
|
||||
|
||||
@@ -1167,10 +1167,6 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
fn should_render_footer(&self) -> bool {
true
}

fn supports_favorites(&self) -> bool {
true
}
}

impl acp_thread::AgentConnection for NativeAgentConnection {
@@ -1,10 +1,14 @@
|
||||
use std::{any::Any, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agent_servers::{AgentServer, AgentServerDelegate};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use prompt_store::PromptStore;
|
||||
use settings::{LanguageModelSelection, Settings as _, update_settings_file};
|
||||
|
||||
use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
|
||||
|
||||
@@ -71,6 +75,38 @@ impl AgentServer for NativeAgentServer {
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
|
||||
fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
|
||||
AgentSettings::get_global(cx).favorite_model_ids()
|
||||
}
|
||||
|
||||
fn toggle_favorite_model(
|
||||
&self,
|
||||
model_id: acp::ModelId,
|
||||
should_be_favorite: bool,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &App,
|
||||
) {
|
||||
let selection = model_id_to_selection(&model_id);
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let agent = settings.agent.get_or_insert_default();
|
||||
if should_be_favorite {
|
||||
agent.add_favorite_model(selection.clone());
|
||||
} else {
|
||||
agent.remove_favorite_model(&selection);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
|
||||
fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
|
||||
let id = model_id.0.as_ref();
|
||||
let (provider, model) = id.split_once('/').unwrap_or(("", id));
|
||||
LanguageModelSelection {
|
||||
provider: provider.to_owned().into(),
|
||||
model: model.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
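For reference, a sketch of what that conversion produces (illustrative values only; `acp::ModelId::new` mirrors its usage elsewhere in this diff, where it is passed an owned string):

```rust
// Sketch of the conversion performed by model_id_to_selection above.
let selection = model_id_to_selection(&acp::ModelId::new("anthropic/claude-3-5-sonnet".to_string()));
// selection.provider == "anthropic", selection.model == "claude-3-5-sonnet"

// Without a '/' separator, the whole ID becomes the model name and the provider stays empty.
let fallback = model_id_to_selection(&acp::ModelId::new("claude-3-5-sonnet".to_string()));
// fallback.provider == "", fallback.model == "claude-3-5-sonnet"
```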
#[cfg(test)]
@@ -21,6 +21,7 @@ acp_tools.workspace = true
acp_thread.workspace = true
action_log.workspace = true
agent-client-protocol.workspace = true
feature_flags.workspace = true
anyhow.workspace = true
async-trait.workspace = true
client.workspace = true
@@ -4,6 +4,7 @@ use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _};
|
||||
use futures::AsyncBufReadExt as _;
|
||||
use futures::io::BufReader;
|
||||
use project::Project;
|
||||
@@ -38,6 +39,7 @@ pub struct AcpConnection {
|
||||
agent_capabilities: acp::AgentCapabilities,
|
||||
default_mode: Option<acp::SessionModeId>,
|
||||
default_model: Option<acp::ModelId>,
|
||||
default_config_options: HashMap<String, String>,
|
||||
root_dir: PathBuf,
|
||||
// NB: Don't move this into the wait_task, since we need to ensure the process is
|
||||
// killed on drop (setting kill_on_drop on the command seems to not always work).
|
||||
@@ -47,11 +49,29 @@ pub struct AcpConnection {
|
||||
_stderr_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
struct ConfigOptions {
|
||||
config_options: Rc<RefCell<Vec<acp::SessionConfigOption>>>,
|
||||
tx: Rc<RefCell<watch::Sender<()>>>,
|
||||
rx: watch::Receiver<()>,
|
||||
}
|
||||
|
||||
impl ConfigOptions {
|
||||
fn new(config_options: Rc<RefCell<Vec<acp::SessionConfigOption>>>) -> Self {
|
||||
let (tx, rx) = watch::channel(());
|
||||
Self {
|
||||
config_options,
|
||||
tx: Rc::new(RefCell::new(tx)),
|
||||
rx,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AcpSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
suppress_abort_err: bool,
|
||||
models: Option<Rc<RefCell<acp::SessionModelState>>>,
|
||||
session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
|
||||
config_options: Option<ConfigOptions>,
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
@@ -60,6 +80,7 @@ pub async fn connect(
|
||||
root_dir: &Path,
|
||||
default_mode: Option<acp::SessionModeId>,
|
||||
default_model: Option<acp::ModelId>,
|
||||
default_config_options: HashMap<String, String>,
|
||||
is_remote: bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Rc<dyn AgentConnection>> {
|
||||
@@ -69,6 +90,7 @@ pub async fn connect(
|
||||
root_dir,
|
||||
default_mode,
|
||||
default_model,
|
||||
default_config_options,
|
||||
is_remote,
|
||||
cx,
|
||||
)
|
||||
@@ -85,6 +107,7 @@ impl AcpConnection {
|
||||
root_dir: &Path,
|
||||
default_mode: Option<acp::SessionModeId>,
|
||||
default_model: Option<acp::ModelId>,
|
||||
default_config_options: HashMap<String, String>,
|
||||
is_remote: bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Self> {
|
||||
@@ -217,6 +240,7 @@ impl AcpConnection {
|
||||
agent_capabilities: response.agent_capabilities,
|
||||
default_mode,
|
||||
default_model,
|
||||
default_config_options,
|
||||
_io_task: io_task,
|
||||
_wait_task: wait_task,
|
||||
_stderr_task: stderr_task,
|
||||
@@ -256,6 +280,7 @@ impl AgentConnection for AcpConnection {
|
||||
let sessions = self.sessions.clone();
|
||||
let default_mode = self.default_mode.clone();
|
||||
let default_model = self.default_model.clone();
|
||||
let default_config_options = self.default_config_options.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
let context_server_store = project.read(cx).context_server_store().read(cx);
|
||||
let mcp_servers = if project.read(cx).is_local() {
|
||||
@@ -322,8 +347,21 @@ impl AgentConnection for AcpConnection {
|
||||
}
|
||||
})?;
|
||||
|
||||
let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
|
||||
let models = response.models.map(|models| Rc::new(RefCell::new(models)));
|
||||
let use_config_options = cx.update(|cx| cx.has_flag::<AcpBetaFeatureFlag>())?;
|
||||
|
||||
// Config options take precedence over legacy modes/models
|
||||
let (modes, models, config_options) = if use_config_options && let Some(opts) = response.config_options {
|
||||
(
|
||||
None,
|
||||
None,
|
||||
Some(Rc::new(RefCell::new(opts))),
|
||||
)
|
||||
} else {
|
||||
// Fall back to legacy modes/models
|
||||
let modes = response.modes.map(|modes| Rc::new(RefCell::new(modes)));
|
||||
let models = response.models.map(|models| Rc::new(RefCell::new(models)));
|
||||
(modes, models, None)
|
||||
};
|
||||
|
||||
if let Some(default_mode) = default_mode {
|
||||
if let Some(modes) = modes.as_ref() {
|
||||
@@ -411,6 +449,92 @@ impl AgentConnection for AcpConnection {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(config_opts) = config_options.as_ref() {
|
||||
let defaults_to_apply: Vec<_> = {
|
||||
let config_opts_ref = config_opts.borrow();
|
||||
config_opts_ref
|
||||
.iter()
|
||||
.filter_map(|config_option| {
|
||||
let default_value = default_config_options.get(&*config_option.id.0)?;
|
||||
|
||||
let is_valid = match &config_option.kind {
|
||||
acp::SessionConfigKind::Select(select) => match &select.options {
|
||||
acp::SessionConfigSelectOptions::Ungrouped(options) => {
|
||||
options.iter().any(|opt| &*opt.value.0 == default_value.as_str())
|
||||
}
|
||||
acp::SessionConfigSelectOptions::Grouped(groups) => groups
|
||||
.iter()
|
||||
.any(|g| g.options.iter().any(|opt| &*opt.value.0 == default_value.as_str())),
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if is_valid {
|
||||
let initial_value = match &config_option.kind {
|
||||
acp::SessionConfigKind::Select(select) => {
|
||||
Some(select.current_value.clone())
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
Some((config_option.id.clone(), default_value.clone(), initial_value))
|
||||
} else {
|
||||
log::warn!(
|
||||
"`{}` is not a valid value for config option `{}` in {}",
|
||||
default_value,
|
||||
config_option.id.0,
|
||||
name
|
||||
);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
for (config_id, default_value, initial_value) in defaults_to_apply {
|
||||
cx.spawn({
|
||||
let default_value_id = acp::SessionConfigValueId::new(default_value.clone());
|
||||
let session_id = response.session_id.clone();
|
||||
let config_id_clone = config_id.clone();
|
||||
let config_opts = config_opts.clone();
|
||||
let conn = conn.clone();
|
||||
async move |_| {
|
||||
let result = conn
|
||||
.set_session_config_option(
|
||||
acp::SetSessionConfigOptionRequest::new(
|
||||
session_id,
|
||||
config_id_clone.clone(),
|
||||
default_value_id,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
if result.is_none() {
|
||||
if let Some(initial) = initial_value {
|
||||
let mut opts = config_opts.borrow_mut();
|
||||
if let Some(opt) = opts.iter_mut().find(|o| o.id == config_id_clone) {
|
||||
if let acp::SessionConfigKind::Select(select) =
|
||||
&mut opt.kind
|
||||
{
|
||||
select.current_value = initial;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let mut opts = config_opts.borrow_mut();
|
||||
if let Some(opt) = opts.iter_mut().find(|o| o.id == config_id) {
|
||||
if let acp::SessionConfigKind::Select(select) = &mut opt.kind {
|
||||
select.current_value = acp::SessionConfigValueId::new(default_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let session_id = response.session_id;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
|
||||
let thread = cx.new(|cx| {
|
||||
@@ -432,6 +556,7 @@ impl AgentConnection for AcpConnection {
|
||||
suppress_abort_err: false,
|
||||
session_modes: modes,
|
||||
models,
|
||||
config_options: config_options.map(|opts| ConfigOptions::new(opts))
|
||||
};
|
||||
sessions.borrow_mut().insert(session_id, session);
|
||||
|
||||
@@ -567,6 +692,25 @@ impl AgentConnection for AcpConnection {
|
||||
}
|
||||
}
|
||||
|
||||
fn session_config_options(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
_cx: &App,
|
||||
) -> Option<Rc<dyn acp_thread::AgentSessionConfigOptions>> {
|
||||
let sessions = self.sessions.borrow();
|
||||
let session = sessions.get(session_id)?;
|
||||
|
||||
let config_opts = session.config_options.as_ref()?;
|
||||
|
||||
Some(Rc::new(AcpSessionConfigOptions {
|
||||
session_id: session_id.clone(),
|
||||
connection: self.connection.clone(),
|
||||
state: config_opts.config_options.clone(),
|
||||
watch_tx: config_opts.tx.clone(),
|
||||
watch_rx: config_opts.rx.clone(),
|
||||
}) as _)
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
@@ -685,6 +829,49 @@ impl acp_thread::AgentModelSelector for AcpModelSelector {
|
||||
}
|
||||
}
|
||||
|
||||
struct AcpSessionConfigOptions {
|
||||
session_id: acp::SessionId,
|
||||
connection: Rc<acp::ClientSideConnection>,
|
||||
state: Rc<RefCell<Vec<acp::SessionConfigOption>>>,
|
||||
watch_tx: Rc<RefCell<watch::Sender<()>>>,
|
||||
watch_rx: watch::Receiver<()>,
|
||||
}
|
||||
|
||||
impl acp_thread::AgentSessionConfigOptions for AcpSessionConfigOptions {
|
||||
fn config_options(&self) -> Vec<acp::SessionConfigOption> {
|
||||
self.state.borrow().clone()
|
||||
}
|
||||
|
||||
fn set_config_option(
|
||||
&self,
|
||||
config_id: acp::SessionConfigId,
|
||||
value: acp::SessionConfigValueId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Vec<acp::SessionConfigOption>>> {
|
||||
let connection = self.connection.clone();
|
||||
let session_id = self.session_id.clone();
|
||||
let state = self.state.clone();
|
||||
|
||||
let watch_tx = self.watch_tx.clone();
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let response = connection
|
||||
.set_session_config_option(acp::SetSessionConfigOptionRequest::new(
|
||||
session_id, config_id, value,
|
||||
))
|
||||
.await?;
|
||||
|
||||
*state.borrow_mut() = response.config_options.clone();
|
||||
watch_tx.borrow_mut().send(()).ok();
|
||||
Ok(response.config_options)
|
||||
})
|
||||
}
|
||||
|
||||
fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
|
||||
Some(self.watch_rx.clone())
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientDelegate {
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
cx: AsyncApp,
|
||||
@@ -778,6 +965,21 @@ impl acp::Client for ClientDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
if let acp::SessionUpdate::ConfigOptionUpdate(acp::ConfigOptionUpdate {
|
||||
config_options,
|
||||
..
|
||||
}) = ¬ification.update
|
||||
{
|
||||
if let Some(opts) = &session.config_options {
|
||||
*opts.config_options.borrow_mut() = config_options.clone();
|
||||
opts.tx.borrow_mut().send(()).ok();
|
||||
} else {
|
||||
log::error!(
|
||||
"Got a `ConfigOptionUpdate` notification, but the agent didn't specify `config_options` during session setup."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Clone so we can inspect meta both before and after handing off to the thread
|
||||
let update_clone = notification.update.clone();
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ pub mod e2e_tests;
|
||||
pub use claude::*;
|
||||
use client::ProxySettings;
|
||||
pub use codex::*;
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, HashSet};
|
||||
pub use custom::*;
|
||||
use fs::Fs;
|
||||
pub use gemini::*;
|
||||
@@ -56,9 +56,19 @@ impl AgentServerDelegate {
|
||||
pub trait AgentServer: Send {
|
||||
fn logo(&self) -> ui::IconName;
|
||||
fn name(&self) -> SharedString;
|
||||
fn default_mode(&self, _cx: &mut App) -> Option<agent_client_protocol::SessionModeId> {
|
||||
fn connect(
|
||||
&self,
|
||||
root_dir: Option<&Path>,
|
||||
delegate: AgentServerDelegate,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>>;
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
||||
|
||||
fn default_mode(&self, _cx: &App) -> Option<agent_client_protocol::SessionModeId> {
|
||||
None
|
||||
}
|
||||
|
||||
fn set_default_mode(
|
||||
&self,
|
||||
_mode_id: Option<agent_client_protocol::SessionModeId>,
|
||||
@@ -67,7 +77,7 @@ pub trait AgentServer: Send {
|
||||
) {
|
||||
}
|
||||
|
||||
fn default_model(&self, _cx: &mut App) -> Option<agent_client_protocol::ModelId> {
|
||||
fn default_model(&self, _cx: &App) -> Option<agent_client_protocol::ModelId> {
|
||||
None
|
||||
}
|
||||
|
||||
@@ -79,14 +89,49 @@ pub trait AgentServer: Send {
|
||||
) {
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
root_dir: Option<&Path>,
|
||||
delegate: AgentServerDelegate,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>>;
|
||||
fn favorite_model_ids(&self, _cx: &mut App) -> HashSet<agent_client_protocol::ModelId> {
|
||||
HashSet::default()
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
||||
fn default_config_option(&self, _config_id: &str, _cx: &App) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
fn set_default_config_option(
|
||||
&self,
|
||||
_config_id: &str,
|
||||
_value_id: Option<&str>,
|
||||
_fs: Arc<dyn Fs>,
|
||||
_cx: &mut App,
|
||||
) {
|
||||
}
|
||||
|
||||
fn favorite_config_option_value_ids(
|
||||
&self,
|
||||
_config_id: &agent_client_protocol::SessionConfigId,
|
||||
_cx: &mut App,
|
||||
) -> HashSet<agent_client_protocol::SessionConfigValueId> {
|
||||
HashSet::default()
|
||||
}
|
||||
|
||||
fn toggle_favorite_config_option_value(
|
||||
&self,
|
||||
_config_id: agent_client_protocol::SessionConfigId,
|
||||
_value_id: agent_client_protocol::SessionConfigValueId,
|
||||
_should_be_favorite: bool,
|
||||
_fs: Arc<dyn Fs>,
|
||||
_cx: &App,
|
||||
) {
|
||||
}
|
||||
|
||||
fn toggle_favorite_model(
|
||||
&self,
|
||||
_model_id: agent_client_protocol::ModelId,
|
||||
_should_be_favorite: bool,
|
||||
_fs: Arc<dyn Fs>,
|
||||
_cx: &App,
|
||||
) {
|
||||
}
|
||||
}
|
||||
|
||||
impl dyn AgentServer {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use agent_client_protocol as acp;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use settings::{SettingsStore, update_settings_file};
|
||||
use std::path::Path;
|
||||
@@ -30,7 +31,7 @@ impl AgentServer for ClaudeCode {
|
||||
ui::IconName::AiClaude
|
||||
}
|
||||
|
||||
fn default_mode(&self, cx: &mut App) -> Option<acp::SessionModeId> {
|
||||
fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
});
|
||||
@@ -51,7 +52,7 @@ impl AgentServer for ClaudeCode {
|
||||
});
|
||||
}
|
||||
|
||||
fn default_model(&self, cx: &mut App) -> Option<acp::ModelId> {
|
||||
fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
});
|
||||
@@ -72,6 +73,139 @@ impl AgentServer for ClaudeCode {
|
||||
});
|
||||
}
|
||||
|
||||
fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
});
|
||||
|
||||
settings
|
||||
.as_ref()
|
||||
.map(|s| {
|
||||
s.favorite_models
|
||||
.iter()
|
||||
.map(|id| acp::ModelId::new(id.clone()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn toggle_favorite_model(
|
||||
&self,
|
||||
model_id: acp::ModelId,
|
||||
should_be_favorite: bool,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &App,
|
||||
) {
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let favorite_models = &mut settings
|
||||
.agent_servers
|
||||
.get_or_insert_default()
|
||||
.claude
|
||||
.get_or_insert_default()
|
||||
.favorite_models;
|
||||
|
||||
let model_id_str = model_id.to_string();
|
||||
if should_be_favorite {
|
||||
if !favorite_models.contains(&model_id_str) {
|
||||
favorite_models.push(model_id_str);
|
||||
}
|
||||
} else {
|
||||
favorite_models.retain(|id| id != &model_id_str);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
});
|
||||
|
||||
settings
|
||||
.as_ref()
|
||||
.and_then(|s| s.default_config_options.get(config_id).cloned())
|
||||
}
|
||||
|
||||
fn set_default_config_option(
|
||||
&self,
|
||||
config_id: &str,
|
||||
value_id: Option<&str>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let config_id = config_id.to_string();
|
||||
let value_id = value_id.map(|s| s.to_string());
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let config_options = &mut settings
|
||||
.agent_servers
|
||||
.get_or_insert_default()
|
||||
.claude
|
||||
.get_or_insert_default()
|
||||
.default_config_options;
|
||||
|
||||
if let Some(value) = value_id.clone() {
|
||||
config_options.insert(config_id.clone(), value);
|
||||
} else {
|
||||
config_options.remove(&config_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn favorite_config_option_value_ids(
|
||||
&self,
|
||||
config_id: &acp::SessionConfigId,
|
||||
cx: &mut App,
|
||||
) -> HashSet<acp::SessionConfigValueId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).claude.clone()
|
||||
});
|
||||
|
||||
settings
|
||||
.as_ref()
|
||||
.and_then(|s| s.favorite_config_option_values.get(config_id.0.as_ref()))
|
||||
.map(|values| {
|
||||
values
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(acp::SessionConfigValueId::new)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn toggle_favorite_config_option_value(
|
||||
&self,
|
||||
config_id: acp::SessionConfigId,
|
||||
value_id: acp::SessionConfigValueId,
|
||||
should_be_favorite: bool,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &App,
|
||||
) {
|
||||
let config_id = config_id.to_string();
|
||||
let value_id = value_id.to_string();
|
||||
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let favorites = &mut settings
|
||||
.agent_servers
|
||||
.get_or_insert_default()
|
||||
.claude
|
||||
.get_or_insert_default()
|
||||
.favorite_config_option_values;
|
||||
|
||||
let entry = favorites.entry(config_id.clone()).or_insert_with(Vec::new);
|
||||
|
||||
if should_be_favorite {
|
||||
if !entry.iter().any(|v| v == &value_id) {
|
||||
entry.push(value_id.clone());
|
||||
}
|
||||
} else {
|
||||
entry.retain(|v| v != &value_id);
|
||||
if entry.is_empty() {
|
||||
favorites.remove(&config_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
root_dir: Option<&Path>,
|
||||
@@ -85,6 +219,14 @@ impl AgentServer for ClaudeCode {
|
||||
let extra_env = load_proxy_env(cx);
|
||||
let default_mode = self.default_mode(cx);
|
||||
let default_model = self.default_model(cx);
|
||||
let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings
|
||||
.get::<AllAgentServersSettings>(None)
|
||||
.claude
|
||||
.as_ref()
|
||||
.map(|s| s.default_config_options.clone())
|
||||
.unwrap_or_default()
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (command, root_dir, login) = store
|
||||
@@ -107,6 +249,7 @@ impl AgentServer for ClaudeCode {
|
||||
root_dir.as_ref(),
|
||||
default_mode,
|
||||
default_model,
|
||||
default_config_options,
|
||||
is_remote,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::{any::Any, path::Path};
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{App, AppContext as _, SharedString, Task};
|
||||
use project::agent_server_store::{AllAgentServersSettings, CODEX_NAME};
|
||||
@@ -31,7 +32,7 @@ impl AgentServer for Codex {
|
||||
ui::IconName::AiOpenAi
|
||||
}
|
||||
|
||||
fn default_mode(&self, cx: &mut App) -> Option<acp::SessionModeId> {
|
||||
fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
});
|
||||
@@ -52,7 +53,7 @@ impl AgentServer for Codex {
|
||||
});
|
||||
}
|
||||
|
||||
fn default_model(&self, cx: &mut App) -> Option<acp::ModelId> {
|
||||
fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
});
|
||||
@@ -73,6 +74,139 @@ impl AgentServer for Codex {
|
||||
});
|
||||
}
|
||||
|
||||
fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
});
|
||||
|
||||
settings
|
||||
.as_ref()
|
||||
.map(|s| {
|
||||
s.favorite_models
|
||||
.iter()
|
||||
.map(|id| acp::ModelId::new(id.clone()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn toggle_favorite_model(
|
||||
&self,
|
||||
model_id: acp::ModelId,
|
||||
should_be_favorite: bool,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &App,
|
||||
) {
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let favorite_models = &mut settings
|
||||
.agent_servers
|
||||
.get_or_insert_default()
|
||||
.codex
|
||||
.get_or_insert_default()
|
||||
.favorite_models;
|
||||
|
||||
let model_id_str = model_id.to_string();
|
||||
if should_be_favorite {
|
||||
if !favorite_models.contains(&model_id_str) {
|
||||
favorite_models.push(model_id_str);
|
||||
}
|
||||
} else {
|
||||
favorite_models.retain(|id| id != &model_id_str);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
});
|
||||
|
||||
settings
|
||||
.as_ref()
|
||||
.and_then(|s| s.default_config_options.get(config_id).cloned())
|
||||
}
|
||||
|
||||
fn set_default_config_option(
|
||||
&self,
|
||||
config_id: &str,
|
||||
value_id: Option<&str>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let config_id = config_id.to_string();
|
||||
let value_id = value_id.map(|s| s.to_string());
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let config_options = &mut settings
|
||||
.agent_servers
|
||||
.get_or_insert_default()
|
||||
.codex
|
||||
.get_or_insert_default()
|
||||
.default_config_options;
|
||||
|
||||
if let Some(value) = value_id.clone() {
|
||||
config_options.insert(config_id.clone(), value);
|
||||
} else {
|
||||
config_options.remove(&config_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn favorite_config_option_value_ids(
|
||||
&self,
|
||||
config_id: &acp::SessionConfigId,
|
||||
cx: &mut App,
|
||||
) -> HashSet<acp::SessionConfigValueId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
});
|
||||
|
||||
settings
|
||||
.as_ref()
|
||||
.and_then(|s| s.favorite_config_option_values.get(config_id.0.as_ref()))
|
||||
.map(|values| {
|
||||
values
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(acp::SessionConfigValueId::new)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn toggle_favorite_config_option_value(
|
||||
&self,
|
||||
config_id: acp::SessionConfigId,
|
||||
value_id: acp::SessionConfigValueId,
|
||||
should_be_favorite: bool,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &App,
|
||||
) {
|
||||
let config_id = config_id.to_string();
|
||||
let value_id = value_id.to_string();
|
||||
|
||||
update_settings_file(fs, cx, move |settings, _| {
|
||||
let favorites = &mut settings
|
||||
.agent_servers
|
||||
.get_or_insert_default()
|
||||
.codex
|
||||
.get_or_insert_default()
|
||||
.favorite_config_option_values;
|
||||
|
||||
let entry = favorites.entry(config_id.clone()).or_insert_with(Vec::new);
|
||||
|
||||
if should_be_favorite {
|
||||
if !entry.iter().any(|v| v == &value_id) {
|
||||
entry.push(value_id.clone());
|
||||
}
|
||||
} else {
|
||||
entry.retain(|v| v != &value_id);
|
||||
if entry.is_empty() {
|
||||
favorites.remove(&config_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
root_dir: Option<&Path>,
|
||||
@@ -86,6 +220,14 @@ impl AgentServer for Codex {
|
||||
let extra_env = load_proxy_env(cx);
|
||||
let default_mode = self.default_mode(cx);
|
||||
let default_model = self.default_model(cx);
|
||||
let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings
|
||||
.get::<AllAgentServersSettings>(None)
|
||||
.codex
|
||||
.as_ref()
|
||||
.map(|s| s.default_config_options.clone())
|
||||
.unwrap_or_default()
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (command, root_dir, login) = store
|
||||
@@ -109,6 +251,7 @@ impl AgentServer for Codex {
|
||||
root_dir.as_ref(),
|
||||
default_mode,
|
||||
default_model,
|
||||
default_config_options,
|
||||
is_remote,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{App, AppContext as _, SharedString, Task};
|
||||
use project::agent_server_store::{AllAgentServersSettings, ExternalAgentServerName};
|
||||
@@ -29,7 +30,7 @@ impl AgentServer for CustomAgentServer {
|
||||
IconName::Terminal
|
||||
}
|
||||
|
||||
fn default_mode(&self, cx: &mut App) -> Option<acp::SessionModeId> {
|
||||
fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings
|
||||
.get::<AllAgentServersSettings>(None)
|
||||
@@ -43,6 +44,86 @@ impl AgentServer for CustomAgentServer {
|
||||
.and_then(|s| s.default_mode().map(acp::SessionModeId::new))
|
||||
}
|
||||
|
||||
    fn favorite_config_option_value_ids(
        &self,
        config_id: &acp::SessionConfigId,
        cx: &mut App,
    ) -> HashSet<acp::SessionConfigValueId> {
        let settings = cx.read_global(|settings: &SettingsStore, _| {
            settings
                .get::<AllAgentServersSettings>(None)
                .custom
                .get(&self.name())
                .cloned()
        });

        settings
            .as_ref()
            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
            .map(|values| {
                values
                    .iter()
                    .cloned()
                    .map(acp::SessionConfigValueId::new)
                    .collect()
            })
            .unwrap_or_default()
    }

    fn toggle_favorite_config_option_value(
        &self,
        config_id: acp::SessionConfigId,
        value_id: acp::SessionConfigValueId,
        should_be_favorite: bool,
        fs: Arc<dyn Fs>,
        cx: &App,
    ) {
        let name = self.name();
        let config_id = config_id.to_string();
        let value_id = value_id.to_string();

        update_settings_file(fs, cx, move |settings, _| {
            let settings = settings
                .agent_servers
                .get_or_insert_default()
                .custom
                .entry(name.clone())
                .or_insert_with(|| settings::CustomAgentServerSettings::Extension {
                    default_model: None,
                    default_mode: None,
                    favorite_models: Vec::new(),
                    default_config_options: Default::default(),
                    favorite_config_option_values: Default::default(),
                });

            match settings {
                settings::CustomAgentServerSettings::Custom {
                    favorite_config_option_values,
                    ..
                }
                | settings::CustomAgentServerSettings::Extension {
                    favorite_config_option_values,
                    ..
                } => {
                    let entry = favorite_config_option_values
                        .entry(config_id.clone())
                        .or_insert_with(Vec::new);

                    if should_be_favorite {
                        if !entry.iter().any(|v| v == &value_id) {
                            entry.push(value_id.clone());
                        }
                    } else {
                        entry.retain(|v| v != &value_id);
                        if entry.is_empty() {
                            favorite_config_option_values.remove(&config_id);
                        }
                    }
                }
            }
        });
    }

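Both `CustomAgentServerSettings::Custom` and `CustomAgentServerSettings::Extension` carry the same `favorite_config_option_values` field, so the match above uses a single or-pattern to bind it regardless of variant. A minimal, self-contained illustration of that pattern (the `Settings` enum here is hypothetical):

```rust
use std::collections::HashMap;

enum Settings {
    Custom { favorites: HashMap<String, Vec<String>> },
    Extension { favorites: HashMap<String, Vec<String>> },
}

fn favorites_mut(settings: &mut Settings) -> &mut HashMap<String, Vec<String>> {
    // One arm handles both variants because the binding has the same name and type.
    match settings {
        Settings::Custom { favorites } | Settings::Extension { favorites } => favorites,
    }
}

fn main() {
    let mut settings = Settings::Extension { favorites: HashMap::new() };
    favorites_mut(&mut settings)
        .entry("model".to_string())
        .or_default()
        .push("sonnet".to_string());
}
```
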
    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
        let name = self.name();
        update_settings_file(fs, cx, move |settings, _| {
@@ -54,6 +135,9 @@ impl AgentServer for CustomAgentServer {
                .or_insert_with(|| settings::CustomAgentServerSettings::Extension {
                    default_model: None,
                    default_mode: None,
                    favorite_models: Vec::new(),
                    default_config_options: Default::default(),
                    favorite_config_option_values: Default::default(),
                });

            match settings {
@@ -65,7 +149,7 @@ impl AgentServer for CustomAgentServer {
        });
    }

    fn default_model(&self, cx: &mut App) -> Option<acp::ModelId> {
    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
        let settings = cx.read_global(|settings: &SettingsStore, _| {
            settings
                .get::<AllAgentServersSettings>(None)
@@ -90,6 +174,9 @@ impl AgentServer for CustomAgentServer {
                .or_insert_with(|| settings::CustomAgentServerSettings::Extension {
                    default_model: None,
                    default_mode: None,
                    favorite_models: Vec::new(),
                    default_config_options: Default::default(),
                    favorite_config_option_values: Default::default(),
                });

            match settings {
@@ -101,6 +188,125 @@ impl AgentServer for CustomAgentServer {
        });
    }

    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
        let settings = cx.read_global(|settings: &SettingsStore, _| {
            settings
                .get::<AllAgentServersSettings>(None)
                .custom
                .get(&self.name())
                .cloned()
        });

        settings
            .as_ref()
            .map(|s| {
                s.favorite_models()
                    .iter()
                    .map(|id| acp::ModelId::new(id.clone()))
                    .collect()
            })
            .unwrap_or_default()
    }

    fn toggle_favorite_model(
        &self,
        model_id: acp::ModelId,
        should_be_favorite: bool,
        fs: Arc<dyn Fs>,
        cx: &App,
    ) {
        let name = self.name();
        update_settings_file(fs, cx, move |settings, _| {
            let settings = settings
                .agent_servers
                .get_or_insert_default()
                .custom
                .entry(name.clone())
                .or_insert_with(|| settings::CustomAgentServerSettings::Extension {
                    default_model: None,
                    default_mode: None,
                    favorite_models: Vec::new(),
                    default_config_options: Default::default(),
                    favorite_config_option_values: Default::default(),
                });

            let favorite_models = match settings {
                settings::CustomAgentServerSettings::Custom {
                    favorite_models, ..
                }
                | settings::CustomAgentServerSettings::Extension {
                    favorite_models, ..
                } => favorite_models,
            };

            let model_id_str = model_id.to_string();
            if should_be_favorite {
                if !favorite_models.contains(&model_id_str) {
                    favorite_models.push(model_id_str);
                }
            } else {
                favorite_models.retain(|id| id != &model_id_str);
            }
        });
    }

    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
        let settings = cx.read_global(|settings: &SettingsStore, _| {
            settings
                .get::<AllAgentServersSettings>(None)
                .custom
                .get(&self.name())
                .cloned()
        });

        settings
            .as_ref()
            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
    }

    fn set_default_config_option(
        &self,
        config_id: &str,
        value_id: Option<&str>,
        fs: Arc<dyn Fs>,
        cx: &mut App,
    ) {
        let name = self.name();
        let config_id = config_id.to_string();
        let value_id = value_id.map(|s| s.to_string());
        update_settings_file(fs, cx, move |settings, _| {
            let settings = settings
                .agent_servers
                .get_or_insert_default()
                .custom
                .entry(name.clone())
                .or_insert_with(|| settings::CustomAgentServerSettings::Extension {
                    default_model: None,
                    default_mode: None,
                    favorite_models: Vec::new(),
                    default_config_options: Default::default(),
                    favorite_config_option_values: Default::default(),
                });

            match settings {
                settings::CustomAgentServerSettings::Custom {
                    default_config_options,
                    ..
                }
                | settings::CustomAgentServerSettings::Extension {
                    default_config_options,
                    ..
                } => {
                    if let Some(value) = value_id.clone() {
                        default_config_options.insert(config_id.clone(), value);
                    } else {
                        default_config_options.remove(&config_id);
                    }
                }
            }
        });
    }

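`set_default_config_option` treats `None` as "clear the override": a `Some(value)` inserts into the per-agent map, while `None` removes the key. A small sketch of that insert-or-remove semantics on plain standard-library types (names are illustrative, not from the diff):

```rust
use std::collections::HashMap;

fn set_default(defaults: &mut HashMap<String, String>, config_id: &str, value: Option<&str>) {
    match value {
        // Setting a value overwrites any previous default for this config option.
        Some(value) => {
            defaults.insert(config_id.to_string(), value.to_string());
        }
        // Clearing removes the key so the agent's own default applies again.
        None => {
            defaults.remove(config_id);
        }
    }
}

fn main() {
    let mut defaults = HashMap::new();
    set_default(&mut defaults, "reasoning-effort", Some("high"));
    assert_eq!(defaults.get("reasoning-effort").map(String::as_str), Some("high"));
    set_default(&mut defaults, "reasoning-effort", None);
    assert!(defaults.is_empty());
}
```
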
    fn connect(
        &self,
        root_dir: Option<&Path>,
@@ -112,6 +318,23 @@ impl AgentServer for CustomAgentServer {
        let is_remote = delegate.project.read(cx).is_via_remote_server();
        let default_mode = self.default_mode(cx);
        let default_model = self.default_model(cx);
        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
            settings
                .get::<AllAgentServersSettings>(None)
                .custom
                .get(&self.name())
                .map(|s| match s {
                    project::agent_server_store::CustomAgentServerSettings::Custom {
                        default_config_options,
                        ..
                    }
                    | project::agent_server_store::CustomAgentServerSettings::Extension {
                        default_config_options,
                        ..
                    } => default_config_options.clone(),
                })
                .unwrap_or_default()
        });
        let store = delegate.store.downgrade();
        let extra_env = load_proxy_env(cx);
        cx.spawn(async move |cx| {
@@ -137,6 +360,7 @@ impl AgentServer for CustomAgentServer {
                root_dir.as_ref(),
                default_mode,
                default_model,
                default_config_options,
                is_remote,
                cx,
            )

@@ -455,20 +455,12 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
        project::agent_server_store::AllAgentServersSettings {
            claude: Some(BuiltinAgentServerSettings {
                path: Some("claude-code-acp".into()),
                args: None,
                env: None,
                ignore_system_version: None,
                default_mode: None,
                default_model: None,
                ..Default::default()
            }),
            gemini: Some(crate::gemini::tests::local_command().into()),
            codex: Some(BuiltinAgentServerSettings {
                path: Some("codex-acp".into()),
                args: None,
                env: None,
                ignore_system_version: None,
                default_mode: None,
                default_model: None,
                ..Default::default()
            }),
            custom: collections::HashMap::default(),
        },

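The test settings now rely on struct-update syntax instead of spelling out every optional field, so adding new fields such as `default_config_options` no longer requires touching the test. A generic sketch of the pattern (the `ServerSettings` struct here is hypothetical):

```rust
#[derive(Default, Debug)]
struct ServerSettings {
    path: Option<String>,
    args: Option<Vec<String>>,
    default_mode: Option<String>,
    // New fields can be added here without breaking the construction below.
    default_config_options: std::collections::HashMap<String, String>,
}

fn main() {
    // Only the fields that matter for the test are set explicitly;
    // everything else falls back to Default.
    let settings = ServerSettings {
        path: Some("codex-acp".into()),
        ..Default::default()
    };
    assert!(settings.args.is_none());
    assert!(settings.default_config_options.is_empty());
}
```
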
@@ -4,9 +4,10 @@ use std::{any::Any, path::Path};
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use anyhow::{Context as _, Result};
use gpui::{App, SharedString, Task};
use gpui::{App, AppContext as _, SharedString, Task};
use language_models::provider::google::GoogleLanguageModelProvider;
use project::agent_server_store::GEMINI_NAME;
use project::agent_server_store::{AllAgentServersSettings, GEMINI_NAME};
use settings::SettingsStore;

#[derive(Clone)]
pub struct Gemini;
@@ -33,6 +34,14 @@ impl AgentServer for Gemini {
        let mut extra_env = load_proxy_env(cx);
        let default_mode = self.default_mode(cx);
        let default_model = self.default_model(cx);
        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
            settings
                .get::<AllAgentServersSettings>(None)
                .gemini
                .as_ref()
                .map(|s| s.default_config_options.clone())
                .unwrap_or_default()
        });

        cx.spawn(async move |cx| {
            extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
@@ -65,6 +74,7 @@ impl AgentServer for Gemini {
                root_dir.as_ref(),
                default_mode,
                default_model,
                default_config_options,
                is_remote,
                cx,
            )

@@ -1,3 +1,4 @@
mod config_options;
mod entry_view_state;
mod message_editor;
mod mode_selector;

772
crates/agent_ui/src/acp/config_options.rs
Normal file
@@ -0,0 +1,772 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::AgentSessionConfigOptions;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_servers::AgentServer;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use fuzzy::StringMatchCandidate;
|
||||
use gpui::{
|
||||
BackgroundExecutor, Context, DismissEvent, Entity, Subscription, Task, Window, prelude::*,
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::popover_menu::PickerPopoverMenu;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use settings::SettingsStore;
|
||||
use ui::{
|
||||
ElevationIndex, IconButton, ListItem, ListItemSpacing, PopoverMenuHandle, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::ui::HoldForDefault;
|
||||
|
||||
const PICKER_THRESHOLD: usize = 5;
|
||||
|
||||
pub struct ConfigOptionsView {
|
||||
config_options: Rc<dyn AgentSessionConfigOptions>,
|
||||
selectors: Vec<Entity<ConfigOptionSelector>>,
|
||||
agent_server: Rc<dyn AgentServer>,
|
||||
fs: Arc<dyn Fs>,
|
||||
config_option_ids: Vec<acp::SessionConfigId>,
|
||||
_refresh_task: Task<()>,
|
||||
}
|
||||
|
||||
impl ConfigOptionsView {
|
||||
pub fn new(
|
||||
config_options: Rc<dyn AgentSessionConfigOptions>,
|
||||
agent_server: Rc<dyn AgentServer>,
|
||||
fs: Arc<dyn Fs>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let selectors = Self::build_selectors(&config_options, &agent_server, &fs, window, cx);
|
||||
let config_option_ids = Self::config_option_ids(&config_options);
|
||||
|
||||
let rx = config_options.watch(cx);
|
||||
let refresh_task = cx.spawn_in(window, async move |this, cx| {
|
||||
if let Some(mut rx) = rx {
|
||||
while let Ok(()) = rx.recv().await {
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.refresh_selectors_if_needed(window, cx);
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
config_options,
|
||||
selectors,
|
||||
agent_server,
|
||||
fs,
|
||||
config_option_ids,
|
||||
_refresh_task: refresh_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn config_option_ids(
|
||||
config_options: &Rc<dyn AgentSessionConfigOptions>,
|
||||
) -> Vec<acp::SessionConfigId> {
|
||||
config_options
|
||||
.config_options()
|
||||
.into_iter()
|
||||
.map(|option| option.id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn refresh_selectors_if_needed(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let current_ids = Self::config_option_ids(&self.config_options);
|
||||
if current_ids != self.config_option_ids {
|
||||
self.config_option_ids = current_ids;
|
||||
self.rebuild_selectors(window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn rebuild_selectors(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.selectors = Self::build_selectors(
|
||||
&self.config_options,
|
||||
&self.agent_server,
|
||||
&self.fs,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn build_selectors(
|
||||
config_options: &Rc<dyn AgentSessionConfigOptions>,
|
||||
agent_server: &Rc<dyn AgentServer>,
|
||||
fs: &Arc<dyn Fs>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Vec<Entity<ConfigOptionSelector>> {
|
||||
config_options
|
||||
.config_options()
|
||||
.into_iter()
|
||||
.map(|option| {
|
||||
let config_options = config_options.clone();
|
||||
let agent_server = agent_server.clone();
|
||||
let fs = fs.clone();
|
||||
cx.new(|cx| {
|
||||
ConfigOptionSelector::new(
|
||||
config_options,
|
||||
option.id.clone(),
|
||||
agent_server,
|
||||
fs,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigOptionsView {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
if self.selectors.is_empty() {
|
||||
return div().into_any_element();
|
||||
}
|
||||
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.children(self.selectors.iter().cloned())
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigOptionSelector {
|
||||
config_options: Rc<dyn AgentSessionConfigOptions>,
|
||||
config_id: acp::SessionConfigId,
|
||||
picker_handle: PopoverMenuHandle<Picker<ConfigOptionPickerDelegate>>,
|
||||
picker: Entity<Picker<ConfigOptionPickerDelegate>>,
|
||||
setting_value: bool,
|
||||
}
|
||||
|
||||
impl ConfigOptionSelector {
|
||||
pub fn new(
|
||||
config_options: Rc<dyn AgentSessionConfigOptions>,
|
||||
config_id: acp::SessionConfigId,
|
||||
agent_server: Rc<dyn AgentServer>,
|
||||
fs: Arc<dyn Fs>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let option_count = config_options
|
||||
.config_options()
|
||||
.iter()
|
||||
.find(|opt| opt.id == config_id)
|
||||
.map(count_config_options)
|
||||
.unwrap_or(0);
|
||||
|
||||
let is_searchable = option_count >= PICKER_THRESHOLD;
|
||||
|
||||
let picker = {
|
||||
let config_options = config_options.clone();
|
||||
let config_id = config_id.clone();
|
||||
let agent_server = agent_server.clone();
|
||||
let fs = fs.clone();
|
||||
cx.new(move |picker_cx| {
|
||||
let delegate = ConfigOptionPickerDelegate::new(
|
||||
config_options,
|
||||
config_id,
|
||||
agent_server,
|
||||
fs,
|
||||
window,
|
||||
picker_cx,
|
||||
);
|
||||
|
||||
if is_searchable {
|
||||
Picker::list(delegate, window, picker_cx)
|
||||
} else {
|
||||
Picker::nonsearchable_list(delegate, window, picker_cx)
|
||||
}
|
||||
.show_scrollbar(true)
|
||||
.width(rems(20.))
|
||||
.max_height(Some(rems(20.).into()))
|
||||
})
|
||||
};
|
||||
|
||||
Self {
|
||||
config_options,
|
||||
config_id,
|
||||
picker_handle: PopoverMenuHandle::default(),
|
||||
picker,
|
||||
setting_value: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn current_option(&self) -> Option<acp::SessionConfigOption> {
|
||||
self.config_options
|
||||
.config_options()
|
||||
.into_iter()
|
||||
.find(|opt| opt.id == self.config_id)
|
||||
}
|
||||
|
||||
fn current_value_name(&self) -> String {
|
||||
let Some(option) = self.current_option() else {
|
||||
return "Unknown".to_string();
|
||||
};
|
||||
|
||||
match &option.kind {
|
||||
acp::SessionConfigKind::Select(select) => {
|
||||
find_option_name(&select.options, &select.current_value)
|
||||
.unwrap_or_else(|| "Unknown".to_string())
|
||||
}
|
||||
_ => "Unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_trigger_button(&self, _window: &mut Window, _cx: &mut Context<Self>) -> Button {
|
||||
let Some(option) = self.current_option() else {
|
||||
return Button::new("config-option-trigger", "Unknown")
|
||||
.label_size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.disabled(true);
|
||||
};
|
||||
|
||||
let icon = if self.picker_handle.is_deployed() {
|
||||
IconName::ChevronUp
|
||||
} else {
|
||||
IconName::ChevronDown
|
||||
};
|
||||
|
||||
Button::new(
|
||||
ElementId::Name(format!("config-option-{}", option.id.0).into()),
|
||||
self.current_value_name(),
|
||||
)
|
||||
.label_size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.icon(icon)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_position(IconPosition::End)
|
||||
.icon_color(Color::Muted)
|
||||
.disabled(self.setting_value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigOptionSelector {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let Some(option) = self.current_option() else {
|
||||
return div().into_any_element();
|
||||
};
|
||||
|
||||
let trigger_button = self.render_trigger_button(window, cx);
|
||||
|
||||
let option_name = option.name.clone();
|
||||
let option_description: Option<SharedString> = option.description.map(Into::into);
|
||||
|
||||
let tooltip = Tooltip::element(move |_window, _cx| {
|
||||
let mut content = v_flex().gap_1().child(Label::new(option_name.clone()));
|
||||
if let Some(desc) = option_description.as_ref() {
|
||||
content = content.child(
|
||||
Label::new(desc.clone())
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
);
|
||||
}
|
||||
content.into_any()
|
||||
});
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.picker.clone(),
|
||||
trigger_button,
|
||||
tooltip,
|
||||
gpui::Corner::BottomRight,
|
||||
cx,
|
||||
)
|
||||
.with_handle(self.picker_handle.clone())
|
||||
.render(window, cx)
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ConfigOptionPickerEntry {
|
||||
Separator(SharedString),
|
||||
Option(ConfigOptionValue),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConfigOptionValue {
|
||||
value: acp::SessionConfigValueId,
|
||||
name: String,
|
||||
description: Option<String>,
|
||||
group: Option<String>,
|
||||
}
|
||||
|
||||
struct ConfigOptionPickerDelegate {
|
||||
config_options: Rc<dyn AgentSessionConfigOptions>,
|
||||
config_id: acp::SessionConfigId,
|
||||
agent_server: Rc<dyn AgentServer>,
|
||||
fs: Arc<dyn Fs>,
|
||||
filtered_entries: Vec<ConfigOptionPickerEntry>,
|
||||
all_options: Vec<ConfigOptionValue>,
|
||||
selected_index: usize,
|
||||
selected_description: Option<(usize, SharedString, bool)>,
|
||||
favorites: HashSet<acp::SessionConfigValueId>,
|
||||
_settings_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl ConfigOptionPickerDelegate {
|
||||
fn new(
|
||||
config_options: Rc<dyn AgentSessionConfigOptions>,
|
||||
config_id: acp::SessionConfigId,
|
||||
agent_server: Rc<dyn AgentServer>,
|
||||
fs: Arc<dyn Fs>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Self {
|
||||
let favorites = agent_server.favorite_config_option_value_ids(&config_id, cx);
|
||||
|
||||
let all_options = extract_options(&config_options, &config_id);
|
||||
let filtered_entries = options_to_picker_entries(&all_options, &favorites);
|
||||
|
||||
let current_value = get_current_value(&config_options, &config_id);
|
||||
let selected_index = current_value
|
||||
.and_then(|current| {
|
||||
filtered_entries.iter().position(|entry| {
|
||||
matches!(entry, ConfigOptionPickerEntry::Option(opt) if opt.value == current)
|
||||
})
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
let agent_server_for_subscription = agent_server.clone();
|
||||
let config_id_for_subscription = config_id.clone();
|
||||
let settings_subscription =
|
||||
cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
|
||||
let new_favorites = agent_server_for_subscription
|
||||
.favorite_config_option_value_ids(&config_id_for_subscription, cx);
|
||||
if new_favorites != picker.delegate.favorites {
|
||||
picker.delegate.favorites = new_favorites;
|
||||
picker.refresh(window, cx);
|
||||
}
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
|
||||
Self {
|
||||
config_options,
|
||||
config_id,
|
||||
agent_server,
|
||||
fs,
|
||||
filtered_entries,
|
||||
all_options,
|
||||
selected_index,
|
||||
selected_description: None,
|
||||
favorites,
|
||||
_settings_subscription: settings_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
fn current_value(&self) -> Option<acp::SessionConfigValueId> {
|
||||
get_current_value(&self.config_options, &self.config_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl PickerDelegate for ConfigOptionPickerDelegate {
|
||||
type ListItem = AnyElement;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.filtered_entries.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn can_select(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Picker<Self>>,
|
||||
) -> bool {
|
||||
match self.filtered_entries.get(ix) {
|
||||
Some(ConfigOptionPickerEntry::Option(_)) => true,
|
||||
Some(ConfigOptionPickerEntry::Separator(_)) | None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
||||
"Select an option…".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
&mut self,
|
||||
query: String,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let all_options = self.all_options.clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let filtered_options = match this
|
||||
.read_with(cx, |_, cx| {
|
||||
if query.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((all_options.clone(), query.clone(), cx.background_executor().clone()))
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
Some((options, q, executor)) => fuzzy_search_options(options, &q, executor).await,
|
||||
None => all_options,
|
||||
};
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate.filtered_entries =
|
||||
options_to_picker_entries(&filtered_options, &this.delegate.favorites);
|
||||
|
||||
let current_value = this.delegate.current_value();
|
||||
let new_index = current_value
|
||||
.and_then(|current| {
|
||||
this.delegate.filtered_entries.iter().position(|entry| {
|
||||
matches!(entry, ConfigOptionPickerEntry::Option(opt) if opt.value == current)
|
||||
})
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
if let Some(ConfigOptionPickerEntry::Option(option)) =
|
||||
self.filtered_entries.get(self.selected_index)
|
||||
{
|
||||
if window.modifiers().secondary() {
|
||||
let default_value = self
|
||||
.agent_server
|
||||
.default_config_option(self.config_id.0.as_ref(), cx);
|
||||
let is_default = default_value.as_deref() == Some(&*option.value.0);
|
||||
|
||||
self.agent_server.set_default_config_option(
|
||||
self.config_id.0.as_ref(),
|
||||
if is_default {
|
||||
None
|
||||
} else {
|
||||
Some(option.value.0.as_ref())
|
||||
},
|
||||
self.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
let task = self.config_options.set_config_option(
|
||||
self.config_id.clone(),
|
||||
option.value.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |_, _| {
|
||||
if let Err(err) = task.await {
|
||||
log::error!("Failed to set config option: {:?}", err);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
cx.defer_in(window, |picker, window, cx| {
|
||||
picker.set_query("", window, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
match self.filtered_entries.get(ix)? {
|
||||
ConfigOptionPickerEntry::Separator(title) => Some(
|
||||
div()
|
||||
.when(ix > 0, |this| this.mt_1())
|
||||
.child(
|
||||
div()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.text_xs()
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.child(title.clone()),
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
ConfigOptionPickerEntry::Option(option) => {
|
||||
let current_value = self.current_value();
|
||||
let is_selected = current_value.as_ref() == Some(&option.value);
|
||||
|
||||
let default_value = self
|
||||
.agent_server
|
||||
.default_config_option(self.config_id.0.as_ref(), cx);
|
||||
let is_default = default_value.as_deref() == Some(&*option.value.0);
|
||||
|
||||
let is_favorite = self.favorites.contains(&option.value);
|
||||
|
||||
let option_name = option.name.clone();
|
||||
let description = option.description.clone();
|
||||
|
||||
Some(
|
||||
div()
|
||||
.id(("config-option-picker-item", ix))
|
||||
.when_some(description, |this, desc| {
|
||||
let desc: SharedString = desc.into();
|
||||
this.on_hover(cx.listener(move |menu, hovered, _, cx| {
|
||||
if *hovered {
|
||||
menu.delegate.selected_description =
|
||||
Some((ix, desc.clone(), is_default));
|
||||
} else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix)
|
||||
{
|
||||
menu.delegate.selected_description = None;
|
||||
}
|
||||
cx.notify();
|
||||
}))
|
||||
})
|
||||
.child(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.child(h_flex().w_full().child(Label::new(option_name).truncate()))
|
||||
.end_slot(div().pr_2().when(is_selected, |this| {
|
||||
this.child(Icon::new(IconName::Check).color(Color::Accent))
|
||||
}))
|
||||
.end_hover_slot(div().pr_1p5().child({
|
||||
let (icon, color, tooltip) = if is_favorite {
|
||||
(IconName::StarFilled, Color::Accent, "Unfavorite")
|
||||
} else {
|
||||
(IconName::Star, Color::Default, "Favorite")
|
||||
};
|
||||
|
||||
let config_id = self.config_id.clone();
|
||||
let value_id = option.value.clone();
|
||||
let agent_server = self.agent_server.clone();
|
||||
let fs = self.fs.clone();
|
||||
|
||||
IconButton::new(("toggle-favorite-config-option", ix), icon)
|
||||
.layer(ElevationIndex::ElevatedSurface)
|
||||
.icon_color(color)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text(tooltip))
|
||||
.on_click(move |_, _, cx| {
|
||||
agent_server.toggle_favorite_config_option_value(
|
||||
config_id.clone(),
|
||||
value_id.clone(),
|
||||
!is_favorite,
|
||||
fs.clone(),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
})),
|
||||
)
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn documentation_aside(
|
||||
&self,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<ui::DocumentationAside> {
|
||||
self.selected_description
|
||||
.as_ref()
|
||||
.map(|(_, description, is_default)| {
|
||||
let description = description.clone();
|
||||
let is_default = *is_default;
|
||||
|
||||
ui::DocumentationAside::new(
|
||||
ui::DocumentationSide::Left,
|
||||
ui::DocumentationEdge::Top,
|
||||
Rc::new(move |_| {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(Label::new(description.clone()))
|
||||
.child(HoldForDefault::new(is_default))
|
||||
.into_any_element()
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_options(
|
||||
config_options: &Rc<dyn AgentSessionConfigOptions>,
|
||||
config_id: &acp::SessionConfigId,
|
||||
) -> Vec<ConfigOptionValue> {
|
||||
let Some(option) = config_options
|
||||
.config_options()
|
||||
.into_iter()
|
||||
.find(|opt| &opt.id == config_id)
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
match &option.kind {
|
||||
acp::SessionConfigKind::Select(select) => match &select.options {
|
||||
acp::SessionConfigSelectOptions::Ungrouped(options) => options
|
||||
.iter()
|
||||
.map(|opt| ConfigOptionValue {
|
||||
value: opt.value.clone(),
|
||||
name: opt.name.clone(),
|
||||
description: opt.description.clone(),
|
||||
group: None,
|
||||
})
|
||||
.collect(),
|
||||
acp::SessionConfigSelectOptions::Grouped(groups) => groups
|
||||
.iter()
|
||||
.flat_map(|group| {
|
||||
group.options.iter().map(|opt| ConfigOptionValue {
|
||||
value: opt.value.clone(),
|
||||
name: opt.name.clone(),
|
||||
description: opt.description.clone(),
|
||||
group: Some(group.name.clone()),
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
_ => Vec::new(),
|
||||
},
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_current_value(
|
||||
config_options: &Rc<dyn AgentSessionConfigOptions>,
|
||||
config_id: &acp::SessionConfigId,
|
||||
) -> Option<acp::SessionConfigValueId> {
|
||||
config_options
|
||||
.config_options()
|
||||
.into_iter()
|
||||
.find(|opt| &opt.id == config_id)
|
||||
.and_then(|opt| match &opt.kind {
|
||||
acp::SessionConfigKind::Select(select) => Some(select.current_value.clone()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn options_to_picker_entries(
    options: &[ConfigOptionValue],
    favorites: &HashSet<acp::SessionConfigValueId>,
) -> Vec<ConfigOptionPickerEntry> {
    let mut entries = Vec::new();

    let mut favorite_options = Vec::new();

    for option in options {
        if favorites.contains(&option.value) {
            favorite_options.push(option.clone());
        }
    }

    if !favorite_options.is_empty() {
        entries.push(ConfigOptionPickerEntry::Separator("Favorites".into()));
        for option in favorite_options {
            entries.push(ConfigOptionPickerEntry::Option(option));
        }

        // If the remaining list would start ungrouped (group == None), insert a separator so
        // Favorites doesn't visually run into the main list.
        if let Some(option) = options.first()
            && option.group.is_none()
        {
            entries.push(ConfigOptionPickerEntry::Separator("All Options".into()));
        }
    }

    let mut current_group: Option<String> = None;
    for option in options {
        if option.group != current_group {
            if let Some(group_name) = &option.group {
                entries.push(ConfigOptionPickerEntry::Separator(
                    group_name.clone().into(),
                ));
            }
            current_group = option.group.clone();
        }
        entries.push(ConfigOptionPickerEntry::Option(option.clone()));
    }

    entries
}

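The helper above flattens options into a single list of picker rows: an optional "Favorites" section first, then the full list with a separator whenever the group changes. A stripped-down version over plain strings shows the same shape (the `Row` enum and `build_rows` helper are simplified stand-ins for the real picker entries):

```rust
use std::collections::HashSet;

#[derive(Debug, PartialEq)]
enum Row {
    Separator(String),
    Option(String),
}

fn build_rows(options: &[(Option<&str>, &str)], favorites: &HashSet<&str>) -> Vec<Row> {
    let mut rows = Vec::new();

    // Collect favorited options into their own leading section.
    let favorite_rows: Vec<_> = options
        .iter()
        .filter(|(_, name)| favorites.contains(name))
        .collect();
    if !favorite_rows.is_empty() {
        rows.push(Row::Separator("Favorites".into()));
        for (_, name) in &favorite_rows {
            rows.push(Row::Option((*name).into()));
        }
        // Keep the Favorites section visually separate when the main list starts ungrouped.
        if options.first().is_some_and(|(group, _)| group.is_none()) {
            rows.push(Row::Separator("All Options".into()));
        }
    }

    // Emit a separator each time the group changes, then the option itself.
    let mut current_group: Option<&str> = None;
    for (group, name) in options {
        if *group != current_group {
            if let Some(group) = group {
                rows.push(Row::Separator((*group).into()));
            }
            current_group = *group;
        }
        rows.push(Row::Option((*name).into()));
    }
    rows
}

fn main() {
    let favorites: HashSet<&str> = ["o3"].into();
    let rows = build_rows(&[(None, "gpt-4.1"), (None, "o3")], &favorites);
    assert_eq!(rows[0], Row::Separator("Favorites".into()));
    assert_eq!(rows[1], Row::Option("o3".into()));
    assert_eq!(rows[2], Row::Separator("All Options".into()));
}
```
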
async fn fuzzy_search_options(
    options: Vec<ConfigOptionValue>,
    query: &str,
    executor: BackgroundExecutor,
) -> Vec<ConfigOptionValue> {
    let candidates = options
        .iter()
        .enumerate()
        .map(|(ix, opt)| StringMatchCandidate::new(ix, &opt.name))
        .collect::<Vec<_>>();

    let mut matches = fuzzy::match_strings(
        &candidates,
        query,
        false,
        true,
        100,
        &Default::default(),
        executor,
    )
    .await;

    matches.sort_unstable_by_key(|mat| {
        let candidate = &candidates[mat.candidate_id];
        (Reverse(OrderedFloat(mat.score)), candidate.id)
    });

    matches
        .into_iter()
        .map(|mat| options[mat.candidate_id].clone())
        .collect()
}

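Sorting by `(Reverse(score), candidate.id)` ranks higher scores first while falling back to the original list position for ties, which keeps equally-scored options in their source order. Since `f64` is not `Ord`, the diff wraps scores in `OrderedFloat`; the sketch below uses an integer score to stay dependency-free (the `Match` struct is illustrative):

```rust
use std::cmp::Reverse;

struct Match {
    candidate_id: usize,
    score: u32,
}

fn main() {
    let mut matches = vec![
        Match { candidate_id: 2, score: 80 },
        Match { candidate_id: 0, score: 95 },
        Match { candidate_id: 1, score: 95 },
    ];

    // Highest score first; ties keep the original candidate order.
    matches.sort_unstable_by_key(|m| (Reverse(m.score), m.candidate_id));

    let order: Vec<usize> = matches.iter().map(|m| m.candidate_id).collect();
    assert_eq!(order, vec![0, 1, 2]);
}
```
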
fn find_option_name(
    options: &acp::SessionConfigSelectOptions,
    value_id: &acp::SessionConfigValueId,
) -> Option<String> {
    match options {
        acp::SessionConfigSelectOptions::Ungrouped(opts) => opts
            .iter()
            .find(|o| &o.value == value_id)
            .map(|o| o.name.clone()),
        acp::SessionConfigSelectOptions::Grouped(groups) => groups.iter().find_map(|group| {
            group
                .options
                .iter()
                .find(|o| &o.value == value_id)
                .map(|o| o.name.clone())
        }),
        _ => None,
    }
}

fn count_config_options(option: &acp::SessionConfigOption) -> usize {
    match &option.kind {
        acp::SessionConfigKind::Select(select) => match &select.options {
            acp::SessionConfigSelectOptions::Ungrouped(options) => options.len(),
            acp::SessionConfigSelectOptions::Grouped(groups) => {
                groups.iter().map(|g| g.options.len()).sum()
            }
            _ => 0,
        },
        _ => 0,
    }
}
@@ -31,7 +31,7 @@ use rope::Point;
|
||||
use settings::Settings;
|
||||
use std::{cell::RefCell, fmt::Write, rc::Rc, sync::Arc};
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use ui::{ContextMenu, prelude::*};
|
||||
use util::{ResultExt, debug_panic};
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::{Chat, PasteRaw};
|
||||
@@ -132,6 +132,21 @@ impl MessageEditor {
|
||||
placement: Some(ContextMenuPlacement::Above),
|
||||
});
|
||||
editor.register_addon(MessageEditorAddon::new());
|
||||
|
||||
editor.set_custom_context_menu(|editor, _point, window, cx| {
|
||||
let has_selection = editor.has_non_empty_selection(&editor.display_snapshot(cx));
|
||||
|
||||
Some(ContextMenu::build(window, cx, |menu, _, _| {
|
||||
menu.action("Cut", Box::new(editor::actions::Cut))
|
||||
.action_disabled_when(
|
||||
!has_selection,
|
||||
"Copy",
|
||||
Box::new(editor::actions::Copy),
|
||||
)
|
||||
.action("Paste", Box::new(editor::actions::Paste))
|
||||
}))
|
||||
});
|
||||
|
||||
editor
|
||||
});
|
||||
let mention_set =
|
||||
|
||||
@@ -3,19 +3,19 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_client_protocol::ModelId;
|
||||
use agent_servers::AgentServer;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::Result;
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::FutureExt;
|
||||
use fuzzy::{StringMatchCandidate, match_strings};
|
||||
use gpui::{
|
||||
Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
|
||||
Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
|
||||
WeakEntity,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use settings::Settings;
|
||||
use settings::SettingsStore;
|
||||
use ui::{DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, prelude::*};
|
||||
use util::ResultExt;
|
||||
use zed_actions::agent::OpenSettings;
|
||||
@@ -54,7 +54,9 @@ pub struct AcpModelPickerDelegate {
|
||||
selected_index: usize,
|
||||
selected_description: Option<(usize, SharedString, bool)>,
|
||||
selected_model: Option<AgentModelInfo>,
|
||||
favorites: HashSet<ModelId>,
|
||||
_refresh_models_task: Task<()>,
|
||||
_settings_subscription: Subscription,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
|
||||
@@ -102,6 +104,19 @@ impl AcpModelPickerDelegate {
|
||||
})
|
||||
};
|
||||
|
||||
let agent_server_for_subscription = agent_server.clone();
|
||||
let settings_subscription =
|
||||
cx.observe_global_in::<SettingsStore>(window, move |picker, window, cx| {
|
||||
// Only refresh if the favorites actually changed to avoid redundant work
|
||||
// when other settings are modified (e.g., user editing settings.json)
|
||||
let new_favorites = agent_server_for_subscription.favorite_model_ids(cx);
|
||||
if new_favorites != picker.delegate.favorites {
|
||||
picker.delegate.favorites = new_favorites;
|
||||
picker.refresh(window, cx);
|
||||
}
|
||||
});
|
||||
let favorites = agent_server.favorite_model_ids(cx);
|
||||
|
||||
Self {
|
||||
selector,
|
||||
agent_server,
|
||||
@@ -111,7 +126,9 @@ impl AcpModelPickerDelegate {
|
||||
selected_model: None,
|
||||
selected_index: 0,
|
||||
selected_description: None,
|
||||
favorites,
|
||||
_refresh_models_task: refresh_models_task,
|
||||
_settings_subscription: settings_subscription,
|
||||
focus_handle,
|
||||
}
|
||||
}
|
||||
@@ -120,40 +137,37 @@ impl AcpModelPickerDelegate {
|
||||
self.selected_model.as_ref()
|
||||
}
|
||||
|
||||
pub fn favorites_count(&self) -> usize {
|
||||
self.favorites.len()
|
||||
}
|
||||
|
||||
pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
if !self.selector.supports_favorites() {
|
||||
if self.favorites.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let favorites = AgentSettings::get_global(cx).favorite_model_ids();
|
||||
|
||||
if favorites.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(models) = self.models.clone() else {
|
||||
let Some(models) = &self.models else {
|
||||
return;
|
||||
};
|
||||
|
||||
let all_models: Vec<AgentModelInfo> = match models {
|
||||
AgentModelList::Flat(list) => list,
|
||||
AgentModelList::Grouped(index_map) => index_map
|
||||
.into_values()
|
||||
.flatten()
|
||||
.collect::<Vec<AgentModelInfo>>(),
|
||||
let all_models: Vec<&AgentModelInfo> = match models {
|
||||
AgentModelList::Flat(list) => list.iter().collect(),
|
||||
AgentModelList::Grouped(index_map) => index_map.values().flatten().collect(),
|
||||
};
|
||||
|
||||
let favorite_models = all_models
|
||||
.iter()
|
||||
.filter(|model| favorites.contains(&model.id))
|
||||
let favorite_models: Vec<_> = all_models
|
||||
.into_iter()
|
||||
.filter(|model| self.favorites.contains(&model.id))
|
||||
.unique_by(|model| &model.id)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
.collect();
|
||||
|
||||
let current_id = self.selected_model.as_ref().map(|m| m.id.clone());
|
||||
if favorite_models.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let current_id = self.selected_model.as_ref().map(|m| &m.id);
|
||||
|
||||
let current_index_in_favorites = current_id
|
||||
.as_ref()
|
||||
.and_then(|id| favorite_models.iter().position(|m| &m.id == id))
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
@@ -220,11 +234,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let favorites = if self.selector.supports_favorites() {
|
||||
AgentSettings::get_global(cx).favorite_model_ids()
|
||||
} else {
|
||||
Default::default()
|
||||
};
|
||||
let favorites = self.favorites.clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let filtered_models = match this
|
||||
@@ -317,21 +327,20 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
let default_model = self.agent_server.default_model(cx);
|
||||
let is_default = default_model.as_ref() == Some(&model_info.id);
|
||||
|
||||
let supports_favorites = self.selector.supports_favorites();
|
||||
|
||||
let is_favorite = *is_favorite;
|
||||
let handle_action_click = {
|
||||
let model_id = model_info.id.clone();
|
||||
let fs = self.fs.clone();
|
||||
let agent_server = self.agent_server.clone();
|
||||
|
||||
move |cx: &App| {
|
||||
crate::favorite_models::toggle_model_id_in_settings(
|
||||
cx.listener(move |_, _, _, cx| {
|
||||
agent_server.toggle_favorite_model(
|
||||
model_id.clone(),
|
||||
!is_favorite,
|
||||
fs.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Some(
|
||||
@@ -357,10 +366,8 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
})
|
||||
.is_selected(is_selected)
|
||||
.is_focused(selected)
|
||||
.when(supports_favorites, |this| {
|
||||
this.is_favorite(is_favorite)
|
||||
.on_toggle_favorite(handle_action_click)
|
||||
}),
|
||||
.is_favorite(is_favorite)
|
||||
.on_toggle_favorite(handle_action_click),
|
||||
)
|
||||
.into_any_element(),
|
||||
)
|
||||
@@ -603,6 +610,46 @@ mod tests {
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fuzzy_match(cx: &mut TestAppContext) {
|
||||
let models = create_model_list(vec![
|
||||
(
|
||||
"zed",
|
||||
vec![
|
||||
"Claude 3.7 Sonnet",
|
||||
"Claude 3.7 Sonnet Thinking",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
),
|
||||
("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
|
||||
("ollama", vec!["mistral", "deepseek"]),
|
||||
]);
|
||||
|
||||
// Results should preserve models order whenever possible.
|
||||
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
|
||||
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
|
||||
// so it should appear first in the results.
|
||||
let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
|
||||
// Fuzzy search
|
||||
let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
|
||||
let models = create_model_list(vec![
|
||||
@@ -739,42 +786,48 @@ mod tests {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fuzzy_match(cx: &mut TestAppContext) {
|
||||
let models = create_model_list(vec![
|
||||
(
|
||||
"zed",
|
||||
vec![
|
||||
"Claude 3.7 Sonnet",
|
||||
"Claude 3.7 Sonnet Thinking",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
),
|
||||
("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
|
||||
("ollama", vec!["mistral", "deepseek"]),
|
||||
fn test_favorites_count_returns_correct_count(_cx: &mut TestAppContext) {
|
||||
let empty_favorites: HashSet<ModelId> = HashSet::default();
|
||||
assert_eq!(empty_favorites.len(), 0);
|
||||
|
||||
let one_favorite = create_favorites(vec!["model-a"]);
|
||||
assert_eq!(one_favorite.len(), 1);
|
||||
|
||||
let multiple_favorites = create_favorites(vec!["model-a", "model-b", "model-c"]);
|
||||
assert_eq!(multiple_favorites.len(), 3);
|
||||
|
||||
let with_duplicates = create_favorites(vec!["model-a", "model-a", "model-b"]);
|
||||
assert_eq!(with_duplicates.len(), 2);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_is_favorite_flag_set_correctly_in_entries(_cx: &mut TestAppContext) {
|
||||
let models = AgentModelList::Flat(vec![
|
||||
acp_thread::AgentModelInfo {
|
||||
id: acp::ModelId::new("favorite-model".to_string()),
|
||||
name: "Favorite".into(),
|
||||
description: None,
|
||||
icon: None,
|
||||
},
|
||||
acp_thread::AgentModelInfo {
|
||||
id: acp::ModelId::new("regular-model".to_string()),
|
||||
name: "Regular".into(),
|
||||
description: None,
|
||||
icon: None,
|
||||
},
|
||||
]);
|
||||
let favorites = create_favorites(vec!["favorite-model"]);
|
||||
|
||||
// Results should preserve models order whenever possible.
|
||||
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
|
||||
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
|
||||
// so it should appear first in the results.
|
||||
let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
let entries = info_list_to_picker_entries(models, &favorites);
|
||||
|
||||
// Fuzzy search
|
||||
let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
for entry in &entries {
|
||||
if let AcpModelPickerEntry::Model(info, is_favorite) = entry {
|
||||
if info.id.0.as_ref() == "favorite-model" {
|
||||
assert!(*is_favorite, "favorite-model should have is_favorite=true");
|
||||
} else if info.id.0.as_ref() == "regular-model" {
|
||||
assert!(!*is_favorite, "regular-model should have is_favorite=false");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,17 +2,13 @@ use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use agent_settings::AgentSettings;
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
use picker::popover_menu::PickerPopoverMenu;
|
||||
use settings::Settings as _;
|
||||
use ui::{ButtonLike, KeyBinding, PopoverMenuHandle, TintColor, Tooltip, prelude::*};
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
use ui::{ButtonLike, PopoverMenuHandle, TintColor, Tooltip, prelude::*};
|
||||
|
||||
use crate::CycleFavoriteModels;
|
||||
use crate::acp::{AcpModelSelector, model_selector::acp_model_selector};
|
||||
use crate::ui::ModelSelectorTooltip;
|
||||
|
||||
pub struct AcpModelSelectorPopover {
|
||||
selector: Entity<AcpModelSelector>,
|
||||
@@ -23,7 +19,7 @@ pub struct AcpModelSelectorPopover {
|
||||
impl AcpModelSelectorPopover {
|
||||
pub(crate) fn new(
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
agent_server: Rc<dyn AgentServer>,
|
||||
agent_server: Rc<dyn agent_servers::AgentServer>,
|
||||
fs: Arc<dyn Fs>,
|
||||
menu_handle: PopoverMenuHandle<AcpModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
@@ -64,7 +60,8 @@ impl AcpModelSelectorPopover {
|
||||
|
||||
impl Render for AcpModelSelectorPopover {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let model = self.selector.read(cx).delegate.active_model();
|
||||
let selector = self.selector.read(cx);
|
||||
let model = selector.delegate.active_model();
|
||||
let model_name = model
|
||||
.as_ref()
|
||||
.map(|model| model.name.clone())
|
||||
@@ -80,43 +77,13 @@ impl Render for AcpModelSelectorPopover {
|
||||
(Color::Muted, IconName::ChevronDown)
|
||||
};
|
||||
|
||||
let tooltip = Tooltip::element({
|
||||
move |_, cx| {
|
||||
let focus_handle = focus_handle.clone();
|
||||
let should_show_cycle_row = !AgentSettings::get_global(cx)
|
||||
.favorite_model_ids()
|
||||
.is_empty();
|
||||
let show_cycle_row = selector.delegate.favorites_count() > 1;
|
||||
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.child(Label::new("Change Model"))
|
||||
.child(KeyBinding::for_action_in(
|
||||
&ToggleModelSelector,
|
||||
&focus_handle,
|
||||
cx,
|
||||
)),
|
||||
)
|
||||
.when(should_show_cycle_row, |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.pt_1()
|
||||
.gap_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.justify_between()
|
||||
.child(Label::new("Cycle Favorited Models"))
|
||||
.child(KeyBinding::for_action_in(
|
||||
&CycleFavoriteModels,
|
||||
&focus_handle,
|
||||
cx,
|
||||
)),
|
||||
)
|
||||
})
|
||||
.into_any()
|
||||
let tooltip = Tooltip::element({
|
||||
move |_, _cx| {
|
||||
ModelSelectorTooltip::new(focus_handle.clone())
|
||||
.show_cycle_row(show_cycle_row)
|
||||
.into_any_element()
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -24,11 +24,11 @@ use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::FutureExt as _;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem,
|
||||
CursorStyle, EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length,
|
||||
ListOffset, ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task,
|
||||
TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
|
||||
ease_in_out, linear_color_stop, linear_gradient, list, point, pulsating_between,
|
||||
Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, CursorStyle,
|
||||
EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset,
|
||||
ListState, PlatformDisplay, SharedString, StyleRefinement, Subscription, Task, TextStyle,
|
||||
TextStyleRefinement, UnderlineStyle, WeakEntity, Window, WindowHandle, div, ease_in_out,
|
||||
linear_color_stop, linear_gradient, list, point, pulsating_between,
|
||||
};
|
||||
use language::Buffer;
|
||||
|
||||
@@ -47,14 +47,16 @@ use terminal_view::terminal_panel::TerminalPanel;
|
||||
use text::Anchor;
|
||||
use theme::{AgentFontSize, ThemeSettings};
|
||||
use ui::{
|
||||
Callout, CommonAnimationExt, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding,
|
||||
PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip, WithScrollbar, prelude::*,
|
||||
Callout, CommonAnimationExt, ContextMenu, ContextMenuEntry, CopyButton, Disclosure, Divider,
|
||||
DividerColor, ElevationIndex, KeyBinding, PopoverMenuHandle, SpinnerLabel, TintColor, Tooltip,
|
||||
WithScrollbar, prelude::*, right_click_menu,
|
||||
};
|
||||
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
|
||||
use workspace::{CollaboratorId, NewTerminal, Workspace};
|
||||
use zed_actions::agent::{Chat, ToggleModelSelector};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
use super::config_options::ConfigOptionsView;
|
||||
use super::entry_view_state::EntryViewState;
|
||||
use crate::acp::AcpModelSelectorPopover;
|
||||
use crate::acp::ModeSelector;
|
||||
@@ -271,12 +273,14 @@ pub struct AcpThreadView {
|
||||
message_editor: Entity<MessageEditor>,
|
||||
focus_handle: FocusHandle,
|
||||
model_selector: Option<Entity<AcpModelSelectorPopover>>,
|
||||
config_options_view: Option<Entity<ConfigOptionsView>>,
|
||||
profile_selector: Option<Entity<ProfileSelector>>,
|
||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
thread_retry_status: Option<RetryStatus>,
|
||||
thread_error: Option<ThreadError>,
|
||||
thread_error_markdown: Option<Entity<Markdown>>,
|
||||
token_limit_callout_dismissed: bool,
|
||||
thread_feedback: ThreadFeedbackState,
|
||||
list_state: ListState,
|
||||
auth_task: Option<Task<()>>,
|
||||
@@ -428,14 +432,15 @@ impl AcpThreadView {
|
||||
login: None,
|
||||
message_editor,
|
||||
model_selector: None,
|
||||
config_options_view: None,
|
||||
profile_selector: None,
|
||||
|
||||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
list_state: list_state,
|
||||
thread_retry_status: None,
|
||||
thread_error: None,
|
||||
thread_error_markdown: None,
|
||||
token_limit_callout_dismissed: false,
|
||||
thread_feedback: Default::default(),
|
||||
auth_task: None,
|
||||
expanded_tool_calls: HashSet::default(),
|
||||
@@ -612,42 +617,64 @@ impl AcpThreadView {
|
||||
|
||||
AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
|
||||
|
||||
this.model_selector = thread
|
||||
// Check for config options first
|
||||
// Config options take precedence over legacy mode/model selectors
|
||||
// (feature flag gating happens at the data layer)
|
||||
let config_options_provider = thread
|
||||
.read(cx)
|
||||
.connection()
|
||||
.model_selector(thread.read(cx).session_id())
|
||||
.map(|selector| {
|
||||
let agent_server = this.agent.clone();
|
||||
let fs = this.project.read(cx).fs().clone();
|
||||
cx.new(|cx| {
|
||||
AcpModelSelectorPopover::new(
|
||||
selector,
|
||||
agent_server,
|
||||
fs,
|
||||
PopoverMenuHandle::default(),
|
||||
this.focus_handle(cx),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
.session_config_options(thread.read(cx).session_id(), cx);
|
||||
|
||||
let mode_selector = thread
|
||||
.read(cx)
|
||||
.connection()
|
||||
.session_modes(thread.read(cx).session_id(), cx)
|
||||
.map(|session_modes| {
|
||||
let fs = this.project.read(cx).fs().clone();
|
||||
let focus_handle = this.focus_handle(cx);
|
||||
cx.new(|_cx| {
|
||||
ModeSelector::new(
|
||||
session_modes,
|
||||
this.agent.clone(),
|
||||
fs,
|
||||
focus_handle,
|
||||
)
|
||||
})
|
||||
});
|
||||
let mode_selector;
|
||||
if let Some(config_options) = config_options_provider {
|
||||
// Use config options - don't create mode_selector or model_selector
|
||||
let agent_server = this.agent.clone();
|
||||
let fs = this.project.read(cx).fs().clone();
|
||||
this.config_options_view = Some(cx.new(|cx| {
|
||||
ConfigOptionsView::new(config_options, agent_server, fs, window, cx)
|
||||
}));
|
||||
this.model_selector = None;
|
||||
mode_selector = None;
|
||||
} else {
|
||||
// Fall back to legacy mode/model selectors
|
||||
this.config_options_view = None;
|
||||
this.model_selector = thread
|
||||
.read(cx)
|
||||
.connection()
|
||||
.model_selector(thread.read(cx).session_id())
|
||||
.map(|selector| {
|
||||
let agent_server = this.agent.clone();
|
||||
let fs = this.project.read(cx).fs().clone();
|
||||
cx.new(|cx| {
|
||||
AcpModelSelectorPopover::new(
|
||||
selector,
|
||||
agent_server,
|
||||
fs,
|
||||
PopoverMenuHandle::default(),
|
||||
this.focus_handle(cx),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
mode_selector = thread
|
||||
.read(cx)
|
||||
.connection()
|
||||
.session_modes(thread.read(cx).session_id(), cx)
|
||||
.map(|session_modes| {
|
||||
let fs = this.project.read(cx).fs().clone();
|
||||
let focus_handle = this.focus_handle(cx);
|
||||
cx.new(|_cx| {
|
||||
ModeSelector::new(
|
||||
session_modes,
|
||||
this.agent.clone(),
|
||||
fs,
|
||||
focus_handle,
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
let mut subscriptions = vec![
|
||||
cx.subscribe_in(&thread, window, Self::handle_thread_event),
|
||||
@@ -1393,6 +1420,7 @@ impl AcpThreadView {
|
||||
fn clear_thread_error(&mut self, cx: &mut Context<Self>) {
|
||||
self.thread_error = None;
|
||||
self.thread_error_markdown = None;
|
||||
self.token_limit_callout_dismissed = true;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -1519,6 +1547,10 @@ impl AcpThreadView {
|
||||
// The connection keeps track of the mode
|
||||
cx.notify();
|
||||
}
|
||||
AcpThreadEvent::ConfigOptionsUpdated(_) => {
|
||||
// The watch task in ConfigOptionsView handles rebuilding selectors
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -2038,7 +2070,7 @@ impl AcpThreadView {
|
||||
}
|
||||
})
|
||||
.text_xs()
|
||||
.child(editor.clone().into_any_element()),
|
||||
.child(editor.clone().into_any_element())
|
||||
)
|
||||
.when(editor_focus, |this| {
|
||||
let base_container = h_flex()
|
||||
@@ -2154,7 +2186,6 @@ impl AcpThreadView {
|
||||
if this_is_blank {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
self.render_thinking_block(
|
||||
entry_ix,
|
||||
@@ -2180,7 +2211,7 @@ impl AcpThreadView {
|
||||
.when(is_last, |this| this.pb_4())
|
||||
.w_full()
|
||||
.text_ui(cx)
|
||||
.child(message_body)
|
||||
.child(self.render_message_context_menu(entry_ix, message_body, cx))
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
@@ -2287,6 +2318,70 @@ impl AcpThreadView {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_message_context_menu(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
message_body: AnyElement,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
let entity = cx.entity();
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
right_click_menu(format!("agent_context_menu-{}", entry_ix))
|
||||
.trigger(move |_, _, _| message_body)
|
||||
.menu(move |window, cx| {
|
||||
let focus = window.focused(cx);
|
||||
let entity = entity.clone();
|
||||
let workspace = workspace.clone();
|
||||
|
||||
ContextMenu::build(window, cx, move |menu, _, cx| {
|
||||
let is_at_top = entity.read(cx).list_state.logical_scroll_top().item_ix == 0;
|
||||
|
||||
let scroll_item = if is_at_top {
|
||||
ContextMenuEntry::new("Scroll to Bottom").handler({
|
||||
let entity = entity.clone();
|
||||
move |_, cx| {
|
||||
entity.update(cx, |this, cx| {
|
||||
this.scroll_to_bottom(cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
} else {
|
||||
ContextMenuEntry::new("Scroll to Top").handler({
|
||||
let entity = entity.clone();
|
||||
move |_, cx| {
|
||||
entity.update(cx, |this, cx| {
|
||||
this.scroll_to_top(cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let open_thread_as_markdown = ContextMenuEntry::new("Open Thread as Markdown")
|
||||
.handler({
|
||||
let entity = entity.clone();
|
||||
let workspace = workspace.clone();
|
||||
move |window, cx| {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
entity
|
||||
.update(cx, |this, cx| {
|
||||
this.open_thread_as_markdown(workspace, window, cx)
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
menu.when_some(focus, |menu, focus| menu.context(focus))
|
||||
.action("Copy", Box::new(markdown::CopyAsMarkdown))
|
||||
.separator()
|
||||
.item(scroll_item)
|
||||
.item(open_thread_as_markdown)
|
||||
})
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn tool_card_header_bg(&self, cx: &Context<Self>) -> Hsla {
|
||||
cx.theme()
|
||||
.colors()
|
||||
@@ -4288,37 +4383,6 @@ impl AcpThreadView {
|
||||
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::expand_message_editor))
|
||||
.on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| {
|
||||
if let Some(profile_selector) = this.profile_selector.as_ref() {
|
||||
profile_selector.read(cx).menu_handle().toggle(window, cx);
|
||||
} else if let Some(mode_selector) = this.mode_selector() {
|
||||
mode_selector.read(cx).menu_handle().toggle(window, cx);
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &CycleModeSelector, window, cx| {
|
||||
if let Some(profile_selector) = this.profile_selector.as_ref() {
|
||||
profile_selector.update(cx, |profile_selector, cx| {
|
||||
profile_selector.cycle_profile(cx);
|
||||
});
|
||||
} else if let Some(mode_selector) = this.mode_selector() {
|
||||
mode_selector.update(cx, |mode_selector, cx| {
|
||||
mode_selector.cycle_mode(window, cx);
|
||||
});
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| {
|
||||
if let Some(model_selector) = this.model_selector.as_ref() {
|
||||
model_selector
|
||||
.update(cx, |model_selector, cx| model_selector.toggle(window, cx));
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &CycleFavoriteModels, window, cx| {
|
||||
if let Some(model_selector) = this.model_selector.as_ref() {
|
||||
model_selector.update(cx, |model_selector, cx| {
|
||||
model_selector.cycle_favorite_models(window, cx);
|
||||
});
|
||||
}
|
||||
}))
|
||||
.p_2()
|
||||
.gap_2()
|
||||
.border_t_1()
|
||||
@@ -4382,8 +4446,12 @@ impl AcpThreadView {
|
||||
.gap_1()
|
||||
.children(self.render_token_usage(cx))
|
||||
.children(self.profile_selector.clone())
|
||||
.children(self.mode_selector().cloned())
|
||||
.children(self.model_selector.clone())
|
||||
// Either config_options_view OR (mode_selector + model_selector)
|
||||
.children(self.config_options_view.clone())
|
||||
.when(self.config_options_view.is_none(), |this| {
|
||||
this.children(self.mode_selector().cloned())
|
||||
.children(self.model_selector.clone())
|
||||
})
|
||||
.child(self.render_send_button(cx)),
|
||||
),
|
||||
)
|
||||
@@ -5358,22 +5426,26 @@ impl AcpThreadView {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_token_limit_callout(
|
||||
&self,
|
||||
line_height: Pixels,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Callout> {
|
||||
fn render_token_limit_callout(&self, cx: &mut Context<Self>) -> Option<Callout> {
|
||||
if self.token_limit_callout_dismissed {
|
||||
return None;
|
||||
}
|
||||
|
||||
let token_usage = self.thread()?.read(cx).token_usage()?;
|
||||
let ratio = token_usage.ratio();
|
||||
|
||||
let (severity, title) = match ratio {
|
||||
let (severity, icon, title) = match ratio {
|
||||
acp_thread::TokenUsageRatio::Normal => return None,
|
||||
acp_thread::TokenUsageRatio::Warning => {
|
||||
(Severity::Warning, "Thread reaching the token limit soon")
|
||||
}
|
||||
acp_thread::TokenUsageRatio::Exceeded => {
|
||||
(Severity::Error, "Thread reached the token limit")
|
||||
}
|
||||
acp_thread::TokenUsageRatio::Warning => (
|
||||
Severity::Warning,
|
||||
IconName::Warning,
|
||||
"Thread reaching the token limit soon",
|
||||
),
|
||||
acp_thread::TokenUsageRatio::Exceeded => (
|
||||
Severity::Error,
|
||||
IconName::XCircle,
|
||||
"Thread reached the token limit",
|
||||
),
|
||||
};
|
||||
|
||||
let burn_mode_available = self.as_native_thread(cx).is_some_and(|thread| {
|
||||
@@ -5393,7 +5465,7 @@ impl AcpThreadView {
|
||||
Some(
|
||||
Callout::new()
|
||||
.severity(severity)
|
||||
.line_height(line_height)
|
||||
.icon(icon)
|
||||
.title(title)
|
||||
.description(description)
|
||||
.actions_slot(
|
||||
@@ -5425,7 +5497,8 @@ impl AcpThreadView {
|
||||
})),
|
||||
)
|
||||
}),
|
||||
),
|
||||
)
|
||||
.dismiss_action(self.dismiss_error_button(cx)),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5848,18 +5921,13 @@ impl AcpThreadView {
    fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
        let message = message.into();

        IconButton::new("copy", IconName::Copy)
            .icon_size(IconSize::Small)
            .tooltip(Tooltip::text("Copy Error Message"))
            .on_click(move |_, _, cx| {
                cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
            })
        CopyButton::new(message).tooltip_label("Copy Error Message")
    }

    fn dismiss_error_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
        IconButton::new("dismiss", IconName::Close)
            .icon_size(IconSize::Small)
            .tooltip(Tooltip::text("Dismiss Error"))
            .tooltip(Tooltip::text("Dismiss"))
            .on_click(cx.listener({
                move |this, _, _, cx| {
                    this.clear_thread_error(cx);
@@ -6005,6 +6073,37 @@ impl Render for AcpThreadView {
|
||||
.on_action(cx.listener(Self::allow_always))
|
||||
.on_action(cx.listener(Self::allow_once))
|
||||
.on_action(cx.listener(Self::reject_once))
|
||||
.on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| {
|
||||
if let Some(profile_selector) = this.profile_selector.as_ref() {
|
||||
profile_selector.read(cx).menu_handle().toggle(window, cx);
|
||||
} else if let Some(mode_selector) = this.mode_selector() {
|
||||
mode_selector.read(cx).menu_handle().toggle(window, cx);
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &CycleModeSelector, window, cx| {
|
||||
if let Some(profile_selector) = this.profile_selector.as_ref() {
|
||||
profile_selector.update(cx, |profile_selector, cx| {
|
||||
profile_selector.cycle_profile(cx);
|
||||
});
|
||||
} else if let Some(mode_selector) = this.mode_selector() {
|
||||
mode_selector.update(cx, |mode_selector, cx| {
|
||||
mode_selector.cycle_mode(window, cx);
|
||||
});
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| {
|
||||
if let Some(model_selector) = this.model_selector.as_ref() {
|
||||
model_selector
|
||||
.update(cx, |model_selector, cx| model_selector.toggle(window, cx));
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &CycleFavoriteModels, window, cx| {
|
||||
if let Some(model_selector) = this.model_selector.as_ref() {
|
||||
model_selector.update(cx, |model_selector, cx| {
|
||||
model_selector.cycle_favorite_models(window, cx);
|
||||
});
|
||||
}
|
||||
}))
|
||||
.track_focus(&self.focus_handle)
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.child(match &self.thread_state {
|
||||
@@ -6088,7 +6187,7 @@ impl Render for AcpThreadView {
|
||||
if let Some(usage_callout) = self.render_usage_callout(line_height, cx) {
|
||||
Some(usage_callout.into_any_element())
|
||||
} else {
|
||||
self.render_token_limit_callout(line_height, cx)
|
||||
self.render_token_limit_callout(cx)
|
||||
.map(|token_limit_callout| token_limit_callout.into_any_element())
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1370,6 +1370,9 @@ async fn open_new_agent_servers_entry_in_settings_editor(
            env: Some(HashMap::default()),
            default_mode: None,
            default_model: None,
            favorite_models: vec![],
            default_config_options: Default::default(),
            favorite_config_option_values: Default::default(),
        },
    );
}

@@ -1363,7 +1363,8 @@ impl AgentDiff {
                | AcpThreadEvent::PromptCapabilitiesUpdated
                | AcpThreadEvent::AvailableCommandsUpdated(_)
                | AcpThreadEvent::Retry(_)
                | AcpThreadEvent::ModeUpdated(_) => {}
                | AcpThreadEvent::ModeUpdated(_)
                | AcpThreadEvent::ConfigOptionsUpdated(_) => {}
            }
        }

@@ -1,6 +1,7 @@
|
||||
use crate::{
|
||||
ModelUsageContext,
|
||||
language_model_selector::{LanguageModelSelector, language_model_selector},
|
||||
ui::ModelSelectorTooltip,
|
||||
};
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle, SharedString};
|
||||
@@ -9,7 +10,6 @@ use picker::popover_menu::PickerPopoverMenu;
|
||||
use settings::update_settings_file;
|
||||
use std::sync::Arc;
|
||||
use ui::{ButtonLike, PopoverMenuHandle, TintColor, Tooltip, prelude::*};
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
|
||||
pub struct AgentModelSelector {
|
||||
selector: Entity<LanguageModelSelector>,
|
||||
@@ -81,6 +81,12 @@ impl AgentModelSelector {
|
||||
pub fn active_model(&self, cx: &App) -> Option<language_model::ConfiguredModel> {
|
||||
self.selector.read(cx).delegate.active_model(cx)
|
||||
}
|
||||
|
||||
pub fn cycle_favorite_models(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.selector.update(cx, |selector, cx| {
|
||||
selector.delegate.cycle_favorite_models(window, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AgentModelSelector {
|
||||
@@ -98,8 +104,18 @@ impl Render for AgentModelSelector {
|
||||
Color::Muted
|
||||
};
|
||||
|
||||
let show_cycle_row = self.selector.read(cx).delegate.favorites_count() > 1;
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let tooltip = Tooltip::element({
|
||||
move |_, _cx| {
|
||||
ModelSelectorTooltip::new(focus_handle.clone())
|
||||
.show_cycle_row(show_cycle_row)
|
||||
.into_any_element()
|
||||
}
|
||||
});
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
@@ -125,9 +141,7 @@ impl Render for AgentModelSelector {
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
move |_window, cx| {
|
||||
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
|
||||
},
|
||||
tooltip,
|
||||
gpui::Corner::TopRight,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
use std::sync::Arc;

use agent_client_protocol::ModelId;
use fs::Fs;
use language_model::LanguageModel;
use settings::{LanguageModelSelection, update_settings_file};
@@ -13,20 +12,11 @@ fn language_model_to_selection(model: &Arc<dyn LanguageModel>) -> LanguageModelS
    }
}

fn model_id_to_selection(model_id: &ModelId) -> LanguageModelSelection {
    let id = model_id.0.as_ref();
    let (provider, model) = id.split_once('/').unwrap_or(("", id));
    LanguageModelSelection {
        provider: provider.to_owned().into(),
        model: model.to_owned(),
    }
}

pub fn toggle_in_settings(
    model: Arc<dyn LanguageModel>,
    should_be_favorite: bool,
    fs: Arc<dyn Fs>,
    cx: &App,
    cx: &mut App,
) {
    let selection = language_model_to_selection(&model);
    update_settings_file(fs, cx, move |settings, _| {
@@ -38,20 +28,3 @@ pub fn toggle_in_settings(
        }
    });
}

pub fn toggle_model_id_in_settings(
    model_id: ModelId,
    should_be_favorite: bool,
    fs: Arc<dyn Fs>,
    cx: &App,
) {
    let selection = model_id_to_selection(&model_id);
    update_settings_file(fs, cx, move |settings, _| {
        let agent = settings.agent.get_or_insert_default();
        if should_be_favorite {
            agent.add_favorite_model(selection.clone());
        } else {
            agent.remove_favorite_model(&selection);
        }
    });
}
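// Editor's illustration (not part of the diff): the removed model_id_to_selection
// helper relied on str::split_once to break a "provider/model" id apart, falling
// back to an empty provider when no '/' is present. The ids below are example
// values, not taken from the change; the behavior shown is standard-library only.
fn split_model_id(id: &str) -> (&str, &str) {
    id.split_once('/').unwrap_or(("", id))
}

#[test]
fn split_model_id_examples() {
    assert_eq!(
        split_model_id("anthropic/claude-3-5-sonnet"),
        ("anthropic", "claude-3-5-sonnet")
    );
    // No separator: the whole id is treated as the model name.
    assert_eq!(split_model_id("gpt-4o"), ("", "gpt-4o"));
}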
@@ -40,7 +40,9 @@ use crate::completion_provider::{
use crate::mention_set::paste_images_as_context;
use crate::mention_set::{MentionSet, crease_for_mention};
use crate::terminal_codegen::TerminalCodegen;
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext};
use crate::{
    CycleFavoriteModels, CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext,
};

actions!(inline_assistant, [ThumbsUpResult, ThumbsDownResult]);

@@ -148,7 +150,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
.into_any_element();
|
||||
|
||||
v_flex()
|
||||
.key_context("PromptEditor")
|
||||
.key_context("InlineAssistant")
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.block_mouse_except_scroll()
|
||||
.size_full()
|
||||
@@ -162,10 +164,6 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| {
|
||||
this.model_selector
|
||||
.update(cx, |model_selector, cx| model_selector.toggle(window, cx));
|
||||
}))
|
||||
.on_action(cx.listener(Self::confirm))
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::move_up))
|
||||
@@ -174,6 +172,15 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
.on_action(cx.listener(Self::thumbs_down))
|
||||
.capture_action(cx.listener(Self::cycle_prev))
|
||||
.capture_action(cx.listener(Self::cycle_next))
|
||||
.on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| {
|
||||
this.model_selector
|
||||
.update(cx, |model_selector, cx| model_selector.toggle(window, cx));
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &CycleFavoriteModels, window, cx| {
|
||||
this.model_selector.update(cx, |model_selector, cx| {
|
||||
model_selector.cycle_favorite_models(window, cx);
|
||||
});
|
||||
}))
|
||||
.child(
|
||||
WithRemSize::new(ui_font_size)
|
||||
.h_full()
|
||||
@@ -855,7 +862,7 @@ impl<T: 'static> PromptEditor<T> {
|
||||
.map(|this| {
|
||||
if rated {
|
||||
this.disabled(true)
|
||||
.icon_color(Color::Ignored)
|
||||
.icon_color(Color::Disabled)
|
||||
.tooltip(move |_, cx| {
|
||||
Tooltip::with_meta(
|
||||
"Good Result",
|
||||
@@ -865,8 +872,15 @@ impl<T: 'static> PromptEditor<T> {
|
||||
)
|
||||
})
|
||||
} else {
|
||||
this.icon_color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Good Result"))
|
||||
this.icon_color(Color::Muted).tooltip(
|
||||
move |_, cx| {
|
||||
Tooltip::for_action(
|
||||
"Good Result",
|
||||
&ThumbsUpResult,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
@@ -879,7 +893,7 @@ impl<T: 'static> PromptEditor<T> {
|
||||
.map(|this| {
|
||||
if rated {
|
||||
this.disabled(true)
|
||||
.icon_color(Color::Ignored)
|
||||
.icon_color(Color::Disabled)
|
||||
.tooltip(move |_, cx| {
|
||||
Tooltip::with_meta(
|
||||
"Bad Result",
|
||||
@@ -889,8 +903,15 @@ impl<T: 'static> PromptEditor<T> {
|
||||
)
|
||||
})
|
||||
} else {
|
||||
this.icon_color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Bad Result"))
|
||||
this.icon_color(Color::Muted).tooltip(
|
||||
move |_, cx| {
|
||||
Tooltip::for_action(
|
||||
"Bad Result",
|
||||
&ThumbsDownResult,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
@@ -1088,7 +1109,6 @@ impl<T: 'static> PromptEditor<T> {
|
||||
let colors = cx.theme().colors();
|
||||
|
||||
div()
|
||||
.key_context("InlineAssistEditor")
|
||||
.size_full()
|
||||
.p_2()
|
||||
.pl_1()
|
||||
|
||||
@@ -20,14 +20,14 @@ use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem}
|
||||
|
||||
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
|
||||
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
|
||||
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &App) + 'static>;
|
||||
type OnToggleFavorite = Arc<dyn Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static>;
|
||||
|
||||
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
|
||||
|
||||
pub fn language_model_selector(
|
||||
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
||||
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
||||
on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
|
||||
on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
|
||||
popover_styles: bool,
|
||||
focus_handle: FocusHandle,
|
||||
window: &mut Window,
|
||||
@@ -133,7 +133,7 @@ impl LanguageModelPickerDelegate {
|
||||
fn new(
|
||||
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
|
||||
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
|
||||
on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &App) + 'static,
|
||||
on_toggle_favorite: impl Fn(Arc<dyn LanguageModel>, bool, &mut App) + 'static,
|
||||
popover_styles: bool,
|
||||
focus_handle: FocusHandle,
|
||||
window: &mut Window,
|
||||
@@ -250,6 +250,10 @@ impl LanguageModelPickerDelegate {
|
||||
(self.get_active_model)(cx)
|
||||
}
|
||||
|
||||
pub fn favorites_count(&self) -> usize {
|
||||
self.all_models.favorites.len()
|
||||
}
|
||||
|
||||
pub fn cycle_favorite_models(&mut self, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
if self.all_models.favorites.is_empty() {
|
||||
return;
|
||||
@@ -561,7 +565,10 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
let handle_action_click = {
|
||||
let model = model_info.model.clone();
|
||||
let on_toggle_favorite = self.on_toggle_favorite.clone();
|
||||
move |cx: &App| on_toggle_favorite(model.clone(), !is_favorite, cx)
|
||||
cx.listener(move |picker, _, window, cx| {
|
||||
on_toggle_favorite(model.clone(), !is_favorite, cx);
|
||||
picker.refresh(window, cx);
|
||||
})
|
||||
};
|
||||
|
||||
Some(
|
||||
|
||||
@@ -12,8 +12,8 @@ use editor::{
|
||||
};
|
||||
use futures::{AsyncReadExt as _, FutureExt as _, future::Shared};
|
||||
use gpui::{
|
||||
Animation, AnimationExt as _, AppContext, ClipboardEntry, Context, Empty, Entity, EntityId,
|
||||
Image, ImageFormat, Img, SharedString, Task, WeakEntity, pulsating_between,
|
||||
AppContext, ClipboardEntry, Context, Empty, Entity, EntityId, Image, ImageFormat, Img,
|
||||
SharedString, Task, WeakEntity,
|
||||
};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use itertools::Either;
|
||||
@@ -32,13 +32,14 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{ButtonLike, Disclosure, TintColor, Toggleable, prelude::*};
|
||||
use ui::{Disclosure, Toggleable, prelude::*};
|
||||
use util::{ResultExt, debug_panic, rel_path::RelPath};
|
||||
use workspace::{Workspace, notifications::NotifyResultExt as _};
|
||||
|
||||
use crate::ui::MentionCrease;
|
||||
|
||||
pub type MentionTask = Shared<Task<Result<Mention, String>>>;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
@@ -754,25 +755,8 @@ fn render_fold_icon_button(
|
||||
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
|
||||
.unwrap_or_default();
|
||||
|
||||
ButtonLike::new(fold_id)
|
||||
.style(ButtonStyle::Filled)
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.toggle_state(is_in_text_selection)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Icon::from_path(icon_path.clone())
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new(label.clone())
|
||||
.size(LabelSize::Small)
|
||||
.buffer_font(cx)
|
||||
.single_line(),
|
||||
),
|
||||
)
|
||||
MentionCrease::new(fold_id, icon_path.clone(), label.clone())
|
||||
.is_toggled(is_in_text_selection)
|
||||
.into_any_element()
|
||||
}
|
||||
})
|
||||
@@ -947,12 +931,14 @@ impl Render for LoadingContext {
|
||||
.editor
|
||||
.update(cx, |editor, cx| editor.is_range_selected(&self.range, cx))
|
||||
.unwrap_or_default();
|
||||
ButtonLike::new(("loading-context", self.id))
|
||||
.style(ButtonStyle::Filled)
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.toggle_state(is_in_text_selection)
|
||||
.when_some(self.image.clone(), |el, image_task| {
|
||||
el.hoverable_tooltip(move |_, cx| {
|
||||
|
||||
let id = ElementId::from(("loading_context", self.id));
|
||||
|
||||
MentionCrease::new(id, self.icon.clone(), self.label.clone())
|
||||
.is_toggled(is_in_text_selection)
|
||||
.is_loading(self.loading.is_some())
|
||||
.when_some(self.image.clone(), |this, image_task| {
|
||||
this.image_preview(move |_, cx| {
|
||||
let image = image_task.peek().cloned().transpose().ok().flatten();
|
||||
let image_task = image_task.clone();
|
||||
cx.new::<ImageHover>(|cx| ImageHover {
|
||||
@@ -971,35 +957,6 @@ impl Render for LoadingContext {
|
||||
.into()
|
||||
})
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Icon::from_path(self.icon.clone())
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new(self.label.clone())
|
||||
.size(LabelSize::Small)
|
||||
.buffer_font(cx)
|
||||
.single_line(),
|
||||
)
|
||||
.map(|el| {
|
||||
if self.loading.is_some() {
|
||||
el.with_animation(
|
||||
"loading-context-crease",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 0.8)),
|
||||
|label, delta| label.opacity(delta),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
el.into_any()
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use crate::{
|
||||
language_model_selector::{LanguageModelSelector, language_model_selector},
|
||||
ui::BurnModeTooltip,
|
||||
ui::{BurnModeTooltip, ModelSelectorTooltip},
|
||||
};
|
||||
use agent_settings::{AgentSettings, CompletionMode};
|
||||
use agent_settings::CompletionMode;
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection, SlashCommandWorkingSet};
|
||||
use assistant_slash_commands::{DefaultSlashCommand, FileSlashCommand, selections_creases};
|
||||
@@ -2252,43 +2252,18 @@ impl TextThreadEditor {
|
||||
.color(color)
|
||||
.size(IconSize::XSmall);
|
||||
|
||||
let tooltip = Tooltip::element({
|
||||
move |_, cx| {
|
||||
let focus_handle = focus_handle.clone();
|
||||
let should_show_cycle_row = !AgentSettings::get_global(cx)
|
||||
.favorite_model_ids()
|
||||
.is_empty();
|
||||
let show_cycle_row = self
|
||||
.language_model_selector
|
||||
.read(cx)
|
||||
.delegate
|
||||
.favorites_count()
|
||||
> 1;
|
||||
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.child(Label::new("Change Model"))
|
||||
.child(KeyBinding::for_action_in(
|
||||
&ToggleModelSelector,
|
||||
&focus_handle,
|
||||
cx,
|
||||
)),
|
||||
)
|
||||
.when(should_show_cycle_row, |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.pt_1()
|
||||
.gap_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.justify_between()
|
||||
.child(Label::new("Cycle Favorited Models"))
|
||||
.child(KeyBinding::for_action_in(
|
||||
&CycleFavoriteModels,
|
||||
&focus_handle,
|
||||
cx,
|
||||
)),
|
||||
)
|
||||
})
|
||||
.into_any()
|
||||
let tooltip = Tooltip::element({
|
||||
move |_, _cx| {
|
||||
ModelSelectorTooltip::new(focus_handle.clone())
|
||||
.show_cycle_row(show_cycle_row)
|
||||
.into_any_element()
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ mod burn_mode_tooltip;
mod claude_code_onboarding_modal;
mod end_trial_upsell;
mod hold_for_default;
mod mention_crease;
mod model_selector_components;
mod onboarding_modal;
mod usage_callout;
@@ -14,6 +15,7 @@ pub use burn_mode_tooltip::*;
pub use claude_code_onboarding_modal::*;
pub use end_trial_upsell::*;
pub use hold_for_default::*;
pub use mention_crease::*;
pub use model_selector_components::*;
pub use onboarding_modal::*;
pub use usage_callout::*;

100  crates/agent_ui/src/ui/mention_crease.rs  Normal file
@@ -0,0 +1,100 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use gpui::{Animation, AnimationExt, AnyView, IntoElement, Window, pulsating_between};
|
||||
use settings::Settings;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{ButtonLike, TintColor, prelude::*};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct MentionCrease {
|
||||
id: ElementId,
|
||||
icon: SharedString,
|
||||
label: SharedString,
|
||||
is_toggled: bool,
|
||||
is_loading: bool,
|
||||
image_preview: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView + 'static>>,
|
||||
}
|
||||
|
||||
impl MentionCrease {
|
||||
pub fn new(
|
||||
id: impl Into<ElementId>,
|
||||
icon: impl Into<SharedString>,
|
||||
label: impl Into<SharedString>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
icon: icon.into(),
|
||||
label: label.into(),
|
||||
is_toggled: false,
|
||||
is_loading: false,
|
||||
image_preview: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_toggled(mut self, is_toggled: bool) -> Self {
|
||||
self.is_toggled = is_toggled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn is_loading(mut self, is_loading: bool) -> Self {
|
||||
self.is_loading = is_loading;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn image_preview(
|
||||
mut self,
|
||||
builder: impl Fn(&mut Window, &mut App) -> AnyView + 'static,
|
||||
) -> Self {
|
||||
self.image_preview = Some(Box::new(builder));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for MentionCrease {
|
||||
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = settings.agent_buffer_font_size(cx);
|
||||
let buffer_font = settings.buffer_font.clone();
|
||||
|
||||
let button_height = DefiniteLength::Absolute(AbsoluteLength::Pixels(
|
||||
px(window.line_height().into()) - px(1.),
|
||||
));
|
||||
|
||||
ButtonLike::new(self.id)
|
||||
.style(ButtonStyle::Outlined)
|
||||
.size(ButtonSize::Compact)
|
||||
.height(button_height)
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.toggle_state(self.is_toggled)
|
||||
.when_some(self.image_preview, |this, image_preview| {
|
||||
this.hoverable_tooltip(image_preview)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.pb_px()
|
||||
.gap_1()
|
||||
.font(buffer_font)
|
||||
.text_size(font_size)
|
||||
.child(
|
||||
Icon::from_path(self.icon.clone())
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(self.label.clone())
|
||||
.map(|this| {
|
||||
if self.is_loading {
|
||||
this.with_animation(
|
||||
"loading-context-crease",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 0.8)),
|
||||
|label, delta| label.opacity(delta),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
this.into_any()
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
use gpui::{Action, FocusHandle, prelude::*};
|
||||
use gpui::{Action, ClickEvent, FocusHandle, prelude::*};
|
||||
use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
|
||||
use crate::CycleFavoriteModels;
|
||||
|
||||
enum ModelIcon {
|
||||
Name(IconName),
|
||||
@@ -48,7 +51,7 @@ pub struct ModelSelectorListItem {
|
||||
is_selected: bool,
|
||||
is_focused: bool,
|
||||
is_favorite: bool,
|
||||
on_toggle_favorite: Option<Box<dyn Fn(&App) + 'static>>,
|
||||
on_toggle_favorite: Option<Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static>>,
|
||||
}
|
||||
|
||||
impl ModelSelectorListItem {
|
||||
@@ -89,7 +92,10 @@ impl ModelSelectorListItem {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn on_toggle_favorite(mut self, handler: impl Fn(&App) + 'static) -> Self {
|
||||
pub fn on_toggle_favorite(
|
||||
mut self,
|
||||
handler: impl Fn(&ClickEvent, &mut Window, &mut App) + 'static,
|
||||
) -> Self {
|
||||
self.on_toggle_favorite = Some(Box::new(handler));
|
||||
self
|
||||
}
|
||||
@@ -141,7 +147,7 @@ impl RenderOnce for ModelSelectorListItem {
|
||||
.icon_color(color)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip(Tooltip::text(tooltip))
|
||||
.on_click(move |_, _, cx| (handle_click)(cx)),
|
||||
.on_click(move |event, window, cx| (handle_click)(event, window, cx)),
|
||||
)
|
||||
}
|
||||
}))
|
||||
@@ -187,3 +193,57 @@ impl RenderOnce for ModelSelectorFooter {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct ModelSelectorTooltip {
|
||||
focus_handle: FocusHandle,
|
||||
show_cycle_row: bool,
|
||||
}
|
||||
|
||||
impl ModelSelectorTooltip {
|
||||
pub fn new(focus_handle: FocusHandle) -> Self {
|
||||
Self {
|
||||
focus_handle,
|
||||
show_cycle_row: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn show_cycle_row(mut self, show: bool) -> Self {
|
||||
self.show_cycle_row = show;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for ModelSelectorTooltip {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.child(Label::new("Change Model"))
|
||||
.child(KeyBinding::for_action_in(
|
||||
&ToggleModelSelector,
|
||||
&self.focus_handle,
|
||||
cx,
|
||||
)),
|
||||
)
|
||||
.when(self.show_cycle_row, |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.pt_1()
|
||||
.gap_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.justify_between()
|
||||
.child(Label::new("Cycle Favorited Models"))
|
||||
.child(KeyBinding::for_action_in(
|
||||
&CycleFavoriteModels,
|
||||
&self.focus_handle,
|
||||
cx,
|
||||
)),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,6 +314,12 @@ impl BufferDiffSnapshot {
        self.inner.hunks.is_empty()
    }

    pub fn base_text_string(&self) -> Option<String> {
        self.inner
            .base_text_exists
            .then(|| self.inner.base_text.text())
    }

    pub fn secondary_diff(&self) -> Option<&BufferDiffSnapshot> {
        self.secondary_diff.as_deref()
    }

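// Editor's sketch (not part of the diff; assumes the zed-internal APIs shown
// elsewhere in this comparison): base_text_string() gives callers the HEAD/base
// text so they can diff it against a buffer snapshot, as capture_example.rs
// does further below in compute_uncommitted_diff.
fn diff_against_base(
    diff_snapshot: &buffer_diff::BufferDiffSnapshot,
    current_text: &str,
) -> Option<String> {
    // None when the diff has no base text (e.g. an untracked file).
    let base = diff_snapshot.base_text_string()?;
    Some(language::unified_diff(&base, current_text))
}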
@@ -113,7 +113,7 @@ impl CopilotSweAgentBot {
    const USER_ID: i32 = 198982749;
    /// The alias of the GitHub copilot user. Although https://api.github.com/users/copilot
    /// yields a 404, GitHub still refers to the copilot bot user as @Copilot in some cases.
    const NAME_ALIAS: &'static str = "copilot";
    const NAME_ALIAS: &'static str = "Copilot";

    /// Returns the `created_at` timestamp for the Dependabot bot user.
    fn created_at() -> &'static NaiveDateTime {

@@ -6745,8 +6745,13 @@ async fn test_preview_tabs(cx: &mut TestAppContext) {
    });

    // Split pane to the right
    pane.update(cx, |pane, cx| {
        pane.split(workspace::SplitDirection::Right, cx);
    pane.update_in(cx, |pane, window, cx| {
        pane.split(
            workspace::SplitDirection::Right,
            workspace::SplitMode::default(),
            window,
            cx,
        );
    });
    cx.run_until_parked();
    let right_pane = workspace.read_with(cx, |workspace, _| workspace.active_pane().clone());

@@ -31,9 +31,9 @@ use smallvec::SmallVec;
|
||||
use std::{mem, sync::Arc};
|
||||
use theme::{ActiveTheme, ThemeSettings};
|
||||
use ui::{
|
||||
Avatar, AvatarAvailabilityIndicator, Button, Color, ContextMenu, Facepile, HighlightedLabel,
|
||||
Icon, IconButton, IconName, IconSize, Indicator, Label, ListHeader, ListItem, Tab, Tooltip,
|
||||
prelude::*, tooltip_container,
|
||||
Avatar, AvatarAvailabilityIndicator, Button, Color, ContextMenu, CopyButton, Facepile,
|
||||
HighlightedLabel, Icon, IconButton, IconName, IconSize, Indicator, Label, ListHeader, ListItem,
|
||||
Tab, Tooltip, prelude::*, tooltip_container,
|
||||
};
|
||||
use util::{ResultExt, TryFutureExt, maybe};
|
||||
use workspace::{
|
||||
@@ -2527,16 +2527,9 @@ impl CollabPanel {
|
||||
|
||||
let button = match section {
|
||||
Section::ActiveCall => channel_link.map(|channel_link| {
|
||||
let channel_link_copy = channel_link;
|
||||
IconButton::new("channel-link", IconName::Copy)
|
||||
.icon_size(IconSize::Small)
|
||||
.size(ButtonSize::None)
|
||||
CopyButton::new(channel_link)
|
||||
.visible_on_hover("section-header")
|
||||
.on_click(move |_, _, cx| {
|
||||
let item = ClipboardItem::new_string(channel_link_copy.clone());
|
||||
cx.write_to_clipboard(item)
|
||||
})
|
||||
.tooltip(Tooltip::text("Copy channel link"))
|
||||
.tooltip_label("Copy Channel Link")
|
||||
.into_any_element()
|
||||
}),
|
||||
Section::Contacts => Some(
|
||||
|
||||
@@ -1579,8 +1579,10 @@ impl Panel for DebugPanel {
        Some(proto::PanelId::DebugPanel)
    }

    fn icon(&self, _window: &Window, _cx: &App) -> Option<IconName> {
        Some(IconName::Debug)
    fn icon(&self, _window: &Window, cx: &App) -> Option<IconName> {
        DebuggerSettings::get_global(cx)
            .button
            .then_some(IconName::Debug)
    }

    fn icon_tooltip(&self, _window: &Window, cx: &App) -> Option<&'static str> {

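// Editor's illustration (not part of the diff): the new icon() implementation
// hides the dock icon whenever the debugger `button` setting is off, using
// bool::then_some. The standard-library behavior it relies on:
fn dock_icon(button_setting_enabled: bool) -> Option<&'static str> {
    button_setting_enabled.then_some("debug-icon")
}

#[test]
fn dock_icon_follows_setting() {
    assert_eq!(dock_icon(true), Some("debug-icon"));
    assert_eq!(dock_icon(false), None); // no icon, so no dock button is rendered
}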
@@ -19,6 +19,7 @@ ai_onboarding.workspace = true
anyhow.workspace = true
arrayvec.workspace = true
brotli.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
@@ -52,7 +53,10 @@ settings.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
toml.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true

391  crates/edit_prediction/src/capture_example.rs  Normal file
@@ -0,0 +1,391 @@
|
||||
use crate::{
|
||||
EditPredictionStore, StoredEvent,
|
||||
cursor_excerpt::editable_and_context_ranges_for_cursor_position, example_spec::ExampleSpec,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use buffer_diff::BufferDiffSnapshot;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{Buffer, ToPoint as _};
|
||||
use project::{Project, WorktreeId};
|
||||
use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
|
||||
use text::BufferSnapshot as TextBufferSnapshot;
|
||||
|
||||
pub fn capture_example(
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_anchor: language::Anchor,
|
||||
cx: &mut App,
|
||||
) -> Option<Task<Result<ExampleSpec>>> {
|
||||
let ep_store = EditPredictionStore::try_global(cx)?;
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let file = snapshot.file()?;
|
||||
let worktree_id = file.worktree_id(cx);
|
||||
let repository = project.read(cx).active_repository(cx)?;
|
||||
let repository_snapshot = repository.read(cx).snapshot();
|
||||
let worktree = project.read(cx).worktree_for_id(worktree_id, cx)?;
|
||||
let cursor_path = worktree.read(cx).root_name().join(file.path());
|
||||
if worktree.read(cx).abs_path() != repository_snapshot.work_directory_abs_path {
|
||||
return None;
|
||||
}
|
||||
|
||||
let repository_url = repository_snapshot
|
||||
.remote_origin_url
|
||||
.clone()
|
||||
.or_else(|| repository_snapshot.remote_upstream_url.clone())?;
|
||||
let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
|
||||
|
||||
let mut events = ep_store.update(cx, |store, cx| {
|
||||
store.edit_history_for_project_with_pause_split_last_event(&project, cx)
|
||||
});
|
||||
|
||||
let git_store = project.read(cx).git_store().clone();
|
||||
|
||||
Some(cx.spawn(async move |mut cx| {
|
||||
let snapshots_by_path =
|
||||
collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
|
||||
|
||||
events.retain(|stored_event| {
|
||||
match stored_event.event.as_ref() {
|
||||
zeta_prompt::Event::BufferChange { path, .. } => {
|
||||
if !snapshots_by_path.contains_key(path) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
});
|
||||
|
||||
let line_comment_prefix = snapshot
|
||||
.language()
|
||||
.and_then(|lang| lang.config().line_comments.first())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_default();
|
||||
let (cursor_excerpt, cursor_offset) = cx
|
||||
.background_executor()
|
||||
.spawn(async move { compute_cursor_excerpt(&snapshot, cursor_anchor) })
|
||||
.await;
|
||||
let uncommitted_diff = cx
|
||||
.background_executor()
|
||||
.spawn(async move { compute_uncommitted_diff(snapshots_by_path) })
|
||||
.await;
|
||||
|
||||
let mut edit_history = String::new();
|
||||
for stored_event in &events {
|
||||
zeta_prompt::write_event(&mut edit_history, &stored_event.event);
|
||||
if !edit_history.ends_with('\n') {
|
||||
edit_history.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
let mut spec = ExampleSpec {
|
||||
name: generate_timestamp_name(),
|
||||
repository_url,
|
||||
revision,
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff,
|
||||
cursor_path: cursor_path.as_std_path().into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history,
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
spec.set_cursor_excerpt(&cursor_excerpt, cursor_offset, &line_comment_prefix);
|
||||
Ok(spec)
|
||||
}))
|
||||
}
|
||||
|
||||
fn compute_cursor_excerpt(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
cursor_anchor: language::Anchor,
|
||||
) -> (String, usize) {
|
||||
use text::ToOffset as _;
|
||||
|
||||
let cursor_point = cursor_anchor.to_point(snapshot);
|
||||
let (_editable_range, context_range) =
|
||||
editable_and_context_ranges_for_cursor_position(cursor_point, snapshot, 100, 50);
|
||||
let context_start_offset = context_range.start.to_offset(snapshot);
|
||||
let cursor_offset = cursor_anchor.to_offset(snapshot);
|
||||
let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
|
||||
let excerpt = snapshot.text_for_range(context_range).collect::<String>();
|
||||
(excerpt, cursor_offset_in_excerpt)
|
||||
}
|
||||
|
||||
async fn collect_snapshots(
|
||||
project: &Entity<Project>,
|
||||
git_store: &Entity<project::git_store::GitStore>,
|
||||
worktree_id: WorktreeId,
|
||||
events: &[StoredEvent],
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
|
||||
let mut snapshots_by_path = HashMap::default();
|
||||
let root_name = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.worktree_for_id(worktree_id, cx)
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.root_name()
|
||||
.to_owned()
|
||||
})?;
|
||||
for stored_event in events {
|
||||
let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
|
||||
if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(path, cx)
|
||||
.filter(|path| path.worktree_id == worktree_id)?;
|
||||
let full_path = root_name.join(&project_path.path).as_std_path().into();
|
||||
Some((project_path, full_path))
|
||||
})? {
|
||||
if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})?
|
||||
.await?;
|
||||
let diff = git_store
|
||||
.update(cx, |git_store, cx| {
|
||||
git_store.open_uncommitted_diff(buffer.clone(), cx)
|
||||
})?
|
||||
.await?;
|
||||
let diff_snapshot = diff.update(cx, |diff, cx| diff.snapshot(cx))?;
|
||||
entry.insert((stored_event.old_snapshot.clone(), diff_snapshot));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(snapshots_by_path)
|
||||
}
|
||||
|
||||
fn compute_uncommitted_diff(
|
||||
snapshots_by_path: HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>,
|
||||
) -> String {
|
||||
let mut uncommitted_diff = String::new();
|
||||
for (full_path, (before_text, diff_snapshot)) in snapshots_by_path {
|
||||
if let Some(head_text) = &diff_snapshot.base_text_string() {
|
||||
let file_diff = language::unified_diff(head_text, &before_text.text());
|
||||
if !file_diff.is_empty() {
|
||||
let path_str = full_path.to_string_lossy();
|
||||
writeln!(uncommitted_diff, "--- a/{path_str}").ok();
|
||||
writeln!(uncommitted_diff, "+++ b/{path_str}").ok();
|
||||
uncommitted_diff.push_str(&file_diff);
|
||||
if !uncommitted_diff.ends_with('\n') {
|
||||
uncommitted_diff.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
uncommitted_diff
|
||||
}
|
||||
|
||||
fn generate_timestamp_name() -> String {
|
||||
let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
|
||||
match format {
|
||||
Ok(format) => {
|
||||
let now = time::OffsetDateTime::now_local()
|
||||
.unwrap_or_else(|_| time::OffsetDateTime::now_utc());
|
||||
now.format(&format)
|
||||
.unwrap_or_else(|_| "unknown-time".to_string())
|
||||
}
|
||||
Err(_) => "unknown-time".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::{Client, UserStore};
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::{AppContext as _, TestAppContext, http_client::FakeHttpClient};
|
||||
use indoc::indoc;
|
||||
use language::{Anchor, Point};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::path::Path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_capture_example(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
let committed_contents = indoc! {"
|
||||
fn main() {
|
||||
one();
|
||||
two();
|
||||
three();
|
||||
four();
|
||||
five();
|
||||
six();
|
||||
seven();
|
||||
eight();
|
||||
nine();
|
||||
}
|
||||
"};
|
||||
|
||||
let disk_contents = indoc! {"
|
||||
fn main() {
|
||||
// comment 1
|
||||
one();
|
||||
two();
|
||||
three();
|
||||
four();
|
||||
five();
|
||||
six();
|
||||
seven();
|
||||
eight();
|
||||
// comment 2
|
||||
nine();
|
||||
}
|
||||
"};
|
||||
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
json!({
|
||||
".git": {},
|
||||
"src": {
|
||||
"main.rs": disk_contents,
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
fs.set_head_for_repo(
|
||||
Path::new("/project/.git"),
|
||||
&[("src/main.rs", committed_contents.to_string())],
|
||||
"abc123def456",
|
||||
);
|
||||
fs.set_remote_for_repo(
|
||||
Path::new("/project/.git"),
|
||||
"origin",
|
||||
"https://github.com/test/repo.git",
|
||||
);
|
||||
|
||||
let project = Project::test(fs.clone(), ["/project".as_ref()], cx).await;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer("/project/src/main.rs", cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ep_store = cx.read(|cx| EditPredictionStore::try_global(cx).unwrap());
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_buffer(&buffer, &project, cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let point = Point::new(6, 0);
|
||||
buffer.edit([(point..point, " // comment 3\n")], None, cx);
|
||||
let point = Point::new(4, 0);
|
||||
buffer.edit([(point..point, " // comment 4\n")], None, cx);
|
||||
|
||||
pretty_assertions::assert_eq!(
|
||||
buffer.text(),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
// comment 1
|
||||
one();
|
||||
two();
|
||||
// comment 4
|
||||
three();
|
||||
four();
|
||||
// comment 3
|
||||
five();
|
||||
six();
|
||||
seven();
|
||||
eight();
|
||||
// comment 2
|
||||
nine();
|
||||
}
|
||||
"}
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let mut example = cx
|
||||
.update(|cx| capture_example(project.clone(), buffer.clone(), Anchor::MIN, cx).unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
example.name = "test".to_string();
|
||||
|
||||
pretty_assertions::assert_eq!(
|
||||
example,
|
||||
ExampleSpec {
|
||||
name: "test".to_string(),
|
||||
repository_url: "https://github.com/test/repo.git".to_string(),
|
||||
revision: "abc123def456".to_string(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: indoc! {"
|
||||
--- a/project/src/main.rs
|
||||
+++ b/project/src/main.rs
|
||||
@@ -1,4 +1,5 @@
|
||||
fn main() {
|
||||
+ // comment 1
|
||||
one();
|
||||
two();
|
||||
three();
|
||||
@@ -7,5 +8,6 @@
|
||||
six();
|
||||
seven();
|
||||
eight();
|
||||
+ // comment 2
|
||||
nine();
|
||||
}
|
||||
"}
|
||||
.to_string(),
|
||||
cursor_path: Path::new("project/src/main.rs").into(),
|
||||
cursor_position: indoc! {"
|
||||
fn main() {
|
||||
^[CURSOR_POSITION]
|
||||
// comment 1
|
||||
one();
|
||||
two();
|
||||
// comment 4
|
||||
three();
|
||||
four();
|
||||
// comment 3
|
||||
five();
|
||||
six();
|
||||
seven();
|
||||
eight();
|
||||
// comment 2
|
||||
nine();
|
||||
}
|
||||
"}
|
||||
.to_string(),
|
||||
edit_history: indoc! {"
|
||||
--- a/project/src/main.rs
|
||||
+++ b/project/src/main.rs
|
||||
@@ -2,8 +2,10 @@
|
||||
// comment 1
|
||||
one();
|
||||
two();
|
||||
+ // comment 4
|
||||
three();
|
||||
four();
|
||||
+ // comment 3
|
||||
five();
|
||||
six();
|
||||
seven();
|
||||
"}
|
||||
.to_string(),
|
||||
expected_patches: Vec::new()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
zlog::init_test();
|
||||
let http_client = FakeHttpClient::with_404_response();
|
||||
let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
EditPredictionStore::global(&client, &user_store, cx);
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -35,6 +35,7 @@ use semver::Version;
|
||||
use serde::de::DeserializeOwned;
|
||||
use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
|
||||
use std::collections::{VecDeque, hash_map};
|
||||
use text::Edit;
|
||||
use workspace::Workspace;
|
||||
|
||||
use std::ops::Range;
|
||||
@@ -57,9 +58,9 @@ pub mod open_ai_response;
|
||||
mod prediction;
|
||||
pub mod sweep_ai;
|
||||
|
||||
#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
|
||||
pub mod udiff;
|
||||
|
||||
mod capture_example;
|
||||
mod zed_edit_prediction_delegate;
|
||||
pub mod zeta1;
|
||||
pub mod zeta2;
|
||||
@@ -74,6 +75,7 @@ pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
pub use crate::sweep_ai::SweepAi;
|
||||
pub use capture_example::capture_example;
|
||||
pub use language_model::ApiKeyState;
|
||||
pub use telemetry_events::EditPredictionRating;
|
||||
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
|
||||
@@ -231,8 +233,15 @@ pub struct EditPredictionFinishedDebugEvent {
|
||||
|
||||
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
|
||||
|
||||
/// An event with associated metadata for reconstructing buffer state.
|
||||
#[derive(Clone)]
|
||||
pub struct StoredEvent {
|
||||
pub event: Arc<zeta_prompt::Event>,
|
||||
pub old_snapshot: TextBufferSnapshot,
|
||||
}
|
||||
|
||||
struct ProjectState {
|
||||
events: VecDeque<Arc<zeta_prompt::Event>>,
|
||||
events: VecDeque<StoredEvent>,
|
||||
last_event: Option<LastEvent>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
@@ -248,7 +257,7 @@ struct ProjectState {
|
||||
}
|
||||
|
||||
impl ProjectState {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
|
||||
self.events
|
||||
.iter()
|
||||
.cloned()
|
||||
@@ -260,7 +269,7 @@ impl ProjectState {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
|
||||
self.events
|
||||
.iter()
|
||||
.cloned()
|
||||
@@ -415,7 +424,7 @@ impl LastEvent {
|
||||
&self,
|
||||
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
|
||||
cx: &App,
|
||||
) -> Option<Arc<zeta_prompt::Event>> {
|
||||
) -> Option<StoredEvent> {
|
||||
let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
|
||||
let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
|
||||
|
||||
@@ -430,19 +439,22 @@ impl LastEvent {
|
||||
})
|
||||
});
|
||||
|
||||
let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
|
||||
let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
|
||||
|
||||
if path == old_path && diff.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
in_open_source_repo,
|
||||
// TODO: Actually detect if this edit was predicted or not
|
||||
predicted: false,
|
||||
}))
|
||||
Some(StoredEvent {
|
||||
event: Arc::new(zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
in_open_source_repo,
|
||||
// TODO: Actually detect if this edit was predicted or not
|
||||
predicted: false,
|
||||
}),
|
||||
old_snapshot: self.old_snapshot.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -475,6 +487,52 @@ impl LastEvent {
        }
    }

pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some(diff)
}

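// Editor's illustration (not part of the diff): how the context window around
// the edited region is clamped in compute_diff_between_snapshots above.
// Assuming a single edit touching rows 8..=12 of a buffer whose last row is 20:
#[test]
fn context_window_clamping_example() {
    const CONTEXT_LINES: u32 = 3;
    let (start_row, end_row, max_row) = (8u32, 12u32, 20u32);
    let context_start_row = start_row.saturating_sub(CONTEXT_LINES);
    let context_end_row = (end_row + 1 + CONTEXT_LINES).min(max_row);
    assert_eq!((context_start_row, context_end_row), (5, 16));
    // Near the top of the buffer the start saturates at row 0 instead of underflowing.
    assert_eq!(1u32.saturating_sub(CONTEXT_LINES), 0);
}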
fn buffer_path_with_id_fallback(
|
||||
file: Option<&Arc<dyn File>>,
|
||||
snapshot: &TextBufferSnapshot,
|
||||
@@ -630,12 +688,14 @@ impl EditPredictionStore {
|
||||
pub fn clear_history(&mut self) {
|
||||
for project_state in self.projects.values_mut() {
|
||||
project_state.events.clear();
|
||||
project_state.last_event.take();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
|
||||
if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
|
||||
project_state.events.clear();
|
||||
project_state.last_event.take();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -643,7 +703,7 @@ impl EditPredictionStore {
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
) -> Vec<StoredEvent> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project_state| project_state.events(cx))
|
||||
@@ -654,7 +714,7 @@ impl EditPredictionStore {
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
) -> Vec<StoredEvent> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project_state| project_state.events_split_by_pause(cx))
|
||||
@@ -1536,8 +1596,10 @@ impl EditPredictionStore {
|
||||
|
||||
self.get_or_init_project(&project, cx);
|
||||
let project_state = self.projects.get(&project.entity_id()).unwrap();
|
||||
let events = project_state.events(cx);
|
||||
let has_events = !events.is_empty();
|
||||
let stored_events = project_state.events(cx);
|
||||
let has_events = !stored_events.is_empty();
|
||||
let events: Vec<Arc<zeta_prompt::Event>> =
|
||||
stored_events.into_iter().map(|e| e.event).collect();
|
||||
let debug_tx = project_state.debug_tx.clone();
|
||||
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
@@ -1984,7 +2046,9 @@ impl EditPredictionStore {
|
||||
"Edit Prediction Rated",
|
||||
rating,
|
||||
inputs = prediction.inputs,
|
||||
output = prediction.edit_preview.as_unified_diff(&prediction.edits),
|
||||
output = prediction
|
||||
.edit_preview
|
||||
.as_unified_diff(prediction.snapshot.file(), &prediction.edits),
|
||||
feedback
|
||||
);
|
||||
self.client.telemetry().flush_events().detach();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::*;
|
||||
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
|
||||
use crate::{compute_diff_between_snapshots, udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
|
||||
use client::{UserStore, test::FakeServer};
|
||||
use clock::{FakeSystemClock, ReplicaId};
|
||||
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
|
||||
@@ -360,7 +360,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
|
||||
ep_store.edit_history_for_project(&project, cx)
|
||||
});
|
||||
assert_eq!(events.len(), 1);
|
||||
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
|
||||
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
|
||||
assert_eq!(
|
||||
diff.as_str(),
|
||||
indoc! {"
|
||||
@@ -377,7 +377,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
|
||||
ep_store.edit_history_for_project_with_pause_split_last_event(&project, cx)
|
||||
});
|
||||
assert_eq!(events.len(), 2);
|
||||
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].as_ref();
|
||||
let zeta_prompt::Event::BufferChange { diff, .. } = events[0].event.as_ref();
|
||||
assert_eq!(
|
||||
diff.as_str(),
|
||||
indoc! {"
|
||||
@@ -389,7 +389,7 @@ async fn test_edit_history_getter_pause_splits_last_event(cx: &mut TestAppContex
|
||||
"}
|
||||
);
|
||||
|
||||
let zeta_prompt::Event::BufferChange { diff, .. } = events[1].as_ref();
|
||||
let zeta_prompt::Event::BufferChange { diff, .. } = events[1].event.as_ref();
|
||||
assert_eq!(
|
||||
diff.as_str(),
|
||||
indoc! {"
|
||||
@@ -2082,6 +2082,74 @@ async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut Te
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| {
|
||||
Buffer::local(
|
||||
indoc! {"
|
||||
zero
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
six
|
||||
seven
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
eleven
|
||||
twelve
|
||||
thirteen
|
||||
fourteen
|
||||
fifteen
|
||||
sixteen
|
||||
seventeen
|
||||
eighteen
|
||||
nineteen
|
||||
twenty
|
||||
twenty-one
|
||||
twenty-two
|
||||
twenty-three
|
||||
twenty-four
|
||||
"},
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let point = Point::new(12, 0);
|
||||
buffer.edit([(point..point, "SECOND INSERTION\n")], None, cx);
|
||||
let point = Point::new(8, 0);
|
||||
buffer.edit([(point..point, "FIRST INSERTION\n")], None, cx);
|
||||
});
|
||||
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
|
||||
|
||||
let diff = compute_diff_between_snapshots(&old_snapshot, &new_snapshot).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
diff,
|
||||
indoc! {"
|
||||
@@ -6,10 +6,12 @@
|
||||
five
|
||||
six
|
||||
seven
|
||||
+FIRST INSERTION
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
eleven
|
||||
+SECOND INSERTION
|
||||
twelve
|
||||
thirteen
|
||||
fourteen
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
zlog::init_test();
|
||||
|
||||
@@ -1,39 +1,90 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write as _, mem, path::Path, sync::Arc};
|
||||
use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
|
||||
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ExampleSpec {
|
||||
#[serde(default)]
|
||||
pub name: String,
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tags: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<String>,
|
||||
#[serde(default)]
|
||||
pub uncommitted_diff: String,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_position: String,
|
||||
pub edit_history: String,
|
||||
pub expected_patch: String,
|
||||
pub expected_patches: Vec<String>,
|
||||
}
|
||||
|
||||
const REASONING_HEADING: &str = "Reasoning";
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct FrontMatter<'a> {
|
||||
repository_url: Cow<'a, str>,
|
||||
revision: Cow<'a, str>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tags: Vec<String>,
|
||||
}
|
||||
|
||||
impl ExampleSpec {
|
||||
/// Generate a sanitized filename for this example.
|
||||
pub fn filename(&self) -> String {
|
||||
self.name
|
||||
.chars()
|
||||
.map(|c| match c {
|
||||
' ' | ':' | '~' | '^' | '?' | '*' | '[' | '\\' | '@' | '{' | '/' | '<' | '>'
|
||||
| '|' | '"' => '-',
|
||||
c => c,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Format this example spec as markdown.
|
||||
pub fn to_markdown(&self) -> String {
|
||||
use std::fmt::Write as _;
|
||||
|
||||
let front_matter = FrontMatter {
|
||||
repository_url: Cow::Borrowed(&self.repository_url),
|
||||
revision: Cow::Borrowed(&self.revision),
|
||||
tags: self.tags.clone(),
|
||||
};
|
||||
let front_matter_toml =
|
||||
toml::to_string_pretty(&front_matter).unwrap_or_else(|_| String::new());
|
||||
|
||||
let mut markdown = String::new();
|
||||
|
||||
_ = writeln!(markdown, "+++");
|
||||
markdown.push_str(&front_matter_toml);
|
||||
if !markdown.ends_with('\n') {
|
||||
markdown.push('\n');
|
||||
}
|
||||
_ = writeln!(markdown, "+++");
|
||||
markdown.push('\n');
|
||||
|
||||
_ = writeln!(markdown, "# {}", self.name);
|
||||
markdown.push('\n');
|
||||
|
||||
_ = writeln!(markdown, "repository_url = {}", self.repository_url);
|
||||
_ = writeln!(markdown, "revision = {}", self.revision);
|
||||
markdown.push('\n');
|
||||
if let Some(reasoning) = &self.reasoning {
|
||||
_ = writeln!(markdown, "## {}", REASONING_HEADING);
|
||||
markdown.push('\n');
|
||||
markdown.push_str(reasoning);
|
||||
if !markdown.ends_with('\n') {
|
||||
markdown.push('\n');
|
||||
}
|
||||
markdown.push('\n');
|
||||
}
|
||||
|
||||
if !self.uncommitted_diff.is_empty() {
|
||||
_ = writeln!(markdown, "## {}", UNCOMMITTED_DIFF_HEADING);
|
||||
@@ -75,34 +126,48 @@ impl ExampleSpec {
|
||||
|
||||
_ = writeln!(markdown, "## {}", EXPECTED_PATCH_HEADING);
|
||||
markdown.push('\n');
|
||||
_ = writeln!(markdown, "```diff");
|
||||
markdown.push_str(&self.expected_patch);
|
||||
if !markdown.ends_with('\n') {
|
||||
for patch in &self.expected_patches {
|
||||
_ = writeln!(markdown, "```diff");
|
||||
markdown.push_str(patch);
|
||||
if !markdown.ends_with('\n') {
|
||||
markdown.push('\n');
|
||||
}
|
||||
_ = writeln!(markdown, "```");
|
||||
markdown.push('\n');
|
||||
}
|
||||
_ = writeln!(markdown, "```");
|
||||
markdown.push('\n');
|
||||
|
||||
markdown
|
||||
}
|
||||
|
||||
/// Parse an example spec from markdown.
|
||||
pub fn from_markdown(name: String, input: &str) -> anyhow::Result<Self> {
|
||||
pub fn from_markdown(mut input: &str) -> anyhow::Result<Self> {
|
||||
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
|
||||
let parser = Parser::new(input);
|
||||
|
||||
let mut spec = ExampleSpec {
|
||||
name,
|
||||
name: String::new(),
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Path::new("").into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
|
||||
if let Some(rest) = input.strip_prefix("+++\n")
|
||||
&& let Some((front_matter, rest)) = rest.split_once("+++\n")
|
||||
{
|
||||
if let Ok(data) = toml::from_str::<FrontMatter<'_>>(front_matter) {
|
||||
spec.repository_url = data.repository_url.into_owned();
|
||||
spec.revision = data.revision.into_owned();
|
||||
spec.tags = data.tags;
|
||||
}
|
||||
input = rest.trim_start();
|
||||
}
|
||||
|
||||
let parser = Parser::new(input);
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
@@ -123,20 +188,9 @@ impl ExampleSpec {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
if let Section::Start = current_section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
spec.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
spec.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
spec.name = mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
@@ -194,7 +248,7 @@ impl ExampleSpec {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
spec.expected_patch = mem::take(&mut text);
|
||||
spec.expected_patches.push(mem::take(&mut text));
|
||||
}
|
||||
Section::Start | Section::Other => {}
|
||||
}
|
||||
@@ -209,4 +263,326 @@ impl ExampleSpec {
|
||||
|
||||
Ok(spec)
|
||||
}
|
||||
|
||||
/// Returns the excerpt of text around the cursor, and the offset of the cursor within that
|
||||
/// excerpt.
|
||||
///
|
||||
/// The cursor's position is marked with a special comment that appears
|
||||
/// below the cursor line, which contains the string `[CURSOR_POSITION]`,
|
||||
/// preceded by an arrow marking the cursor's column. The arrow can be
|
||||
/// either:
|
||||
/// - `^` - The cursor column is at the position of the `^` character (pointing up to the cursor)
|
||||
/// - `<` - The cursor column is at the first non-whitespace character on that line.
|
||||
pub fn cursor_excerpt(&self) -> Result<(String, usize)> {
|
||||
let input = &self.cursor_position;
|
||||
|
||||
// Check for inline cursor marker first
|
||||
if let Some(inline_offset) = input.find(INLINE_CURSOR_MARKER) {
|
||||
let excerpt = input[..inline_offset].to_string()
|
||||
+ &input[inline_offset + INLINE_CURSOR_MARKER.len()..];
|
||||
return Ok((excerpt, inline_offset));
|
||||
}
|
||||
|
||||
let marker_offset = input
|
||||
.find(CURSOR_POSITION_MARKER)
|
||||
.context("missing [CURSOR_POSITION] marker")?;
|
||||
let marker_line_start = input[..marker_offset]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let marker_line_end = input[marker_line_start..]
|
||||
.find('\n')
|
||||
.map(|pos| marker_line_start + pos + 1)
|
||||
.unwrap_or(input.len());
|
||||
let marker_line = &input[marker_line_start..marker_line_end].trim_end_matches('\n');
|
||||
|
||||
let cursor_column = if let Some(cursor_offset) = marker_line.find('^') {
|
||||
cursor_offset
|
||||
} else if let Some(less_than_pos) = marker_line.find('<') {
|
||||
marker_line
|
||||
.find(|c: char| !c.is_whitespace())
|
||||
.unwrap_or(less_than_pos)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"cursor position marker line must contain '^' or '<' before [CURSOR_POSITION]"
|
||||
);
|
||||
};
|
||||
|
||||
let mut excerpt = input[..marker_line_start].to_string() + &input[marker_line_end..];
|
||||
excerpt.truncate(excerpt.trim_end_matches('\n').len());
|
||||
|
||||
// The cursor is on the line above the marker line.
|
||||
let cursor_line_end = marker_line_start.saturating_sub(1);
|
||||
let cursor_line_start = excerpt[..cursor_line_end]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let cursor_offset = cursor_line_start + cursor_column;
|
||||
|
||||
Ok((excerpt, cursor_offset))
|
||||
}
|
||||
|
||||
/// Sets the cursor position excerpt from a plain excerpt and cursor byte offset.
|
||||
///
|
||||
/// The `line_comment_prefix` is used to format the marker line as a comment.
|
||||
/// If the cursor column is less than the comment prefix length, the `<` format is used.
|
||||
/// Otherwise, the `^` format is used.
|
||||
pub fn set_cursor_excerpt(
|
||||
&mut self,
|
||||
excerpt: &str,
|
||||
cursor_offset: usize,
|
||||
line_comment_prefix: &str,
|
||||
) {
|
||||
// Find which line the cursor is on and its column
|
||||
let cursor_line_start = excerpt[..cursor_offset]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let cursor_line_end = excerpt[cursor_line_start..]
|
||||
.find('\n')
|
||||
.map(|pos| cursor_line_start + pos + 1)
|
||||
.unwrap_or(excerpt.len());
|
||||
let cursor_line = &excerpt[cursor_line_start..cursor_line_end];
|
||||
let cursor_line_indent = &cursor_line[..cursor_line.len() - cursor_line.trim_start().len()];
|
||||
let cursor_column = cursor_offset - cursor_line_start;
|
||||
|
||||
// Build the marker line
|
||||
let mut marker_line = String::new();
|
||||
if cursor_column < line_comment_prefix.len() {
|
||||
for _ in 0..cursor_column {
|
||||
marker_line.push(' ');
|
||||
}
|
||||
marker_line.push_str(line_comment_prefix);
|
||||
write!(marker_line, " <{}", CURSOR_POSITION_MARKER).unwrap();
|
||||
} else {
|
||||
if cursor_column >= cursor_line_indent.len() + line_comment_prefix.len() {
|
||||
marker_line.push_str(cursor_line_indent);
|
||||
}
|
||||
marker_line.push_str(line_comment_prefix);
|
||||
while marker_line.len() < cursor_column {
|
||||
marker_line.push(' ');
|
||||
}
|
||||
write!(marker_line, "^{}", CURSOR_POSITION_MARKER).unwrap();
|
||||
}
|
||||
|
||||
// Build the final cursor_position string
|
||||
let mut result = String::with_capacity(excerpt.len() + marker_line.len() + 2);
|
||||
result.push_str(&excerpt[..cursor_line_end]);
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str(&marker_line);
|
||||
if cursor_line_end < excerpt.len() {
|
||||
result.push('\n');
|
||||
result.push_str(&excerpt[cursor_line_end..]);
|
||||
}
|
||||
|
||||
self.cursor_position = result;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
|
||||
#[test]
|
||||
fn test_cursor_excerpt_with_caret() {
|
||||
let mut spec = ExampleSpec {
|
||||
name: String::new(),
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Path::new("test.rs").into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
|
||||
// Cursor before `42`
|
||||
let excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
};
|
||||
let offset = excerpt.find("42").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor after `l` in `let`
|
||||
let offset = excerpt.find("et x").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor before `let`
|
||||
let offset = excerpt.find("let").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor at beginning of the line with `let`
|
||||
let offset = excerpt.find(" let").unwrap();
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// <[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Cursor at end of line, after the semicolon
|
||||
let offset = excerpt.find(';').unwrap() + 1;
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
|
||||
// Caret at end of file (no trailing newline)
|
||||
let excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;"
|
||||
};
|
||||
let offset = excerpt.find(';').unwrap() + 1;
|
||||
let position_string = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
// ^[CURSOR_POSITION]"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
spec.set_cursor_excerpt(excerpt, offset, "//");
|
||||
assert_eq!(spec.cursor_position, position_string);
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(excerpt.to_string(), offset)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cursor_excerpt_with_inline_marker() {
|
||||
let mut spec = ExampleSpec {
|
||||
name: String::new(),
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
tags: Vec::new(),
|
||||
reasoning: None,
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Path::new("test.rs").into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patches: Vec::new(),
|
||||
};
|
||||
|
||||
// Cursor before `42` using inline marker
|
||||
spec.cursor_position = indoc! {"
|
||||
fn main() {
|
||||
let x = <|user_cursor|>42;
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let expected_excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
println!(\"{}\", x);
|
||||
}"
|
||||
};
|
||||
let expected_offset = expected_excerpt.find("42").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(expected_excerpt.to_string(), expected_offset)
|
||||
);
|
||||
|
||||
// Cursor at beginning of line
|
||||
spec.cursor_position = indoc! {"
|
||||
fn main() {
|
||||
<|user_cursor|> let x = 42;
|
||||
}"
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let expected_excerpt = indoc! {"
|
||||
fn main() {
|
||||
let x = 42;
|
||||
}"
|
||||
};
|
||||
let expected_offset = expected_excerpt.find(" let").unwrap();
|
||||
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(expected_excerpt.to_string(), expected_offset)
|
||||
);
|
||||
|
||||
// Cursor at end of file
|
||||
spec.cursor_position = "fn main() {}<|user_cursor|>".to_string();
|
||||
let expected_excerpt = "fn main() {}";
|
||||
let expected_offset = expected_excerpt.len();
|
||||
|
||||
assert_eq!(
|
||||
spec.cursor_excerpt().unwrap(),
|
||||
(expected_excerpt.to_string(), expected_offset)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
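The rewritten `ExampleSpec` parser above reads TOML front matter delimited by `+++` and takes the example name from the first `#` heading. A minimal round-trip sketch, assuming `indoc` is available as it already is in these tests; the repository URL, revision, and tag values are made up for illustration:

```rust
use indoc::indoc;

let input = indoc! {r#"
    +++
    repository_url = "https://github.com/octocat/hello-world.git"
    revision = "0123abcd"
    tags = ["needs-context"]
    +++

    # my-example
"#};

let spec = ExampleSpec::from_markdown(input).unwrap();
assert_eq!(spec.name, "my-example");
assert_eq!(spec.repository_url, "https://github.com/octocat/hello-world.git");
assert_eq!(spec.revision, "0123abcd");
assert_eq!(spec.tags, vec!["needs-context".to_string()]);
```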
@@ -14,10 +14,8 @@ use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use gpui::Entity;
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use project::{Project, ProjectPath};
|
||||
use util::paths::PathStyle;
|
||||
use util::rel_path::RelPath;
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
|
||||
use project::Project;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
|
||||
@@ -30,54 +28,26 @@ pub async fn apply_diff(
|
||||
) -> Result<OpenedBuffers> {
|
||||
let mut included_files = HashMap::default();
|
||||
|
||||
let worktree_id = project.read_with(cx, |project, cx| {
|
||||
anyhow::Ok(
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("no worktrees")?
|
||||
.read(cx)
|
||||
.id(),
|
||||
)
|
||||
})??;
|
||||
|
||||
for line in diff_str.lines() {
|
||||
let diff_line = DiffLine::parse(line);
|
||||
|
||||
if let DiffLine::OldPath { path } = diff_line {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?.into_arc(),
|
||||
};
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
included_files.insert(path.to_string(), buffer);
|
||||
}
|
||||
}
|
||||
|
||||
let ranges = [Anchor::MIN..Anchor::MAX];
|
||||
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
let mut current_file = None;
|
||||
let mut edits = vec![];
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk {
|
||||
path: file_path,
|
||||
hunk,
|
||||
} => {
|
||||
let (buffer, ranges) = match current_file {
|
||||
DiffEvent::Hunk { path, hunk } => {
|
||||
let buffer = match current_file {
|
||||
None => {
|
||||
let buffer = included_files
|
||||
.get_mut(file_path.as_ref())
|
||||
.expect("Opened all files in diff");
|
||||
|
||||
current_file = Some((buffer, ranges.as_slice()));
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project
|
||||
.find_project_path(path.as_ref(), cx)
|
||||
.context("no such path")?;
|
||||
anyhow::Ok(project.open_buffer(project_path, cx))
|
||||
})??
|
||||
.await?;
|
||||
included_files.insert(path.to_string(), buffer.clone());
|
||||
current_file = Some(buffer);
|
||||
current_file.as_ref().unwrap()
|
||||
}
|
||||
Some(ref current) => current,
|
||||
@@ -85,14 +55,14 @@ pub async fn apply_diff(
|
||||
|
||||
buffer.read_with(cx, |buffer, _| {
|
||||
edits.extend(
|
||||
resolve_hunk_edits_in_buffer(hunk, buffer, ranges)
|
||||
resolve_hunk_edits_in_buffer(hunk, buffer, ranges.as_slice())
|
||||
.with_context(|| format!("Diff:\n{diff_str}"))?,
|
||||
);
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
DiffEvent::FileEnd { renamed_to } => {
|
||||
let (buffer, _) = current_file
|
||||
let buffer = current_file
|
||||
.take()
|
||||
.context("Got a FileEnd event before an Hunk event")?;
|
||||
|
||||
@@ -128,10 +98,69 @@ pub async fn apply_diff(
|
||||
Ok(OpenedBuffers(included_files))
|
||||
}
|
||||
|
||||
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
/// Extract the diff for a specific file from a multi-file diff.
|
||||
/// Returns an error if the file is not found in the diff.
|
||||
pub fn extract_file_diff(full_diff: &str, file_path: &str) -> Result<String> {
|
||||
let mut result = String::new();
|
||||
let mut in_target_file = false;
|
||||
let mut found_file = false;
|
||||
|
||||
for line in full_diff.lines() {
|
||||
if line.starts_with("diff --git") {
|
||||
if in_target_file {
|
||||
break;
|
||||
}
|
||||
in_target_file = line.contains(&format!("a/{}", file_path))
|
||||
|| line.contains(&format!("b/{}", file_path));
|
||||
if in_target_file {
|
||||
found_file = true;
|
||||
}
|
||||
}
|
||||
|
||||
if in_target_file {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
if !found_file {
|
||||
anyhow::bail!("File '{}' not found in diff", file_path);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Strip unnecessary git metadata lines from a diff, keeping only the lines
|
||||
/// needed for patch application: path headers (--- and +++), hunk headers (@@),
|
||||
/// and content lines (+, -, space).
|
||||
pub fn strip_diff_metadata(diff: &str) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
for line in diff.lines() {
|
||||
let dominated = DiffLine::parse(line);
|
||||
match dominated {
|
||||
// Keep path headers, hunk headers, and content lines
|
||||
DiffLine::OldPath { .. }
|
||||
| DiffLine::NewPath { .. }
|
||||
| DiffLine::HunkHeader(_)
|
||||
| DiffLine::Context(_)
|
||||
| DiffLine::Deletion(_)
|
||||
| DiffLine::Addition(_) => {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
// Skip garbage lines (diff --git, index, etc.)
|
||||
DiffLine::Garbage(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn apply_diff_to_string(original: &str, diff_str: &str) -> Result<String> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
|
||||
let mut text = text.to_string();
|
||||
let mut text = original.to_string();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
@@ -151,6 +180,51 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
/// Returns the individual edits that would be applied by a diff to the given content.
|
||||
/// Each edit is a tuple of (byte_range_in_content, replacement_text).
|
||||
/// Uses sub-line diffing to find the precise character positions of changes.
|
||||
/// Returns an empty vec if the hunk context is not found or is ambiguous.
|
||||
pub fn edits_for_diff(content: &str, diff_str: &str) -> Result<Vec<(Range<usize>, String)>> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
let mut result = Vec::new();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
if hunk.context.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Find the context in the content
|
||||
let first_match = content.find(&hunk.context);
|
||||
let Some(context_offset) = first_match else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
// Check for ambiguity - if context appears more than once, reject
|
||||
if content[context_offset + 1..].contains(&hunk.context) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Use sub-line diffing to find precise edit positions
|
||||
for edit in &hunk.edits {
|
||||
let old_text = &content
|
||||
[context_offset + edit.range.start..context_offset + edit.range.end];
|
||||
let edits_within_hunk = text_diff(old_text, &edit.text);
|
||||
for (inner_range, inner_text) in edits_within_hunk {
|
||||
let absolute_start = context_offset + edit.range.start + inner_range.start;
|
||||
let absolute_end = context_offset + edit.range.start + inner_range.end;
|
||||
result.push((absolute_start..absolute_end, inner_text.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
DiffEvent::FileEnd { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
struct PatchFile<'a> {
|
||||
old_path: Cow<'a, str>,
|
||||
new_path: Cow<'a, str>,
|
||||
@@ -873,4 +947,135 @@ mod tests {
|
||||
|
||||
FakeFs::new(cx.background_executor.clone())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_file_diff() {
|
||||
let multi_file_diff = indoc! {r#"
|
||||
diff --git a/file1.txt b/file1.txt
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/file1.txt
|
||||
+++ b/file1.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
line1
|
||||
+added line
|
||||
line2
|
||||
line3
|
||||
diff --git a/file2.txt b/file2.txt
|
||||
index 2345678..bcdefgh 100644
|
||||
--- a/file2.txt
|
||||
+++ b/file2.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-old line
|
||||
+new line
|
||||
unchanged
|
||||
"#};
|
||||
|
||||
let file1_diff = extract_file_diff(multi_file_diff, "file1.txt").unwrap();
|
||||
assert_eq!(
|
||||
file1_diff,
|
||||
indoc! {r#"
|
||||
diff --git a/file1.txt b/file1.txt
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/file1.txt
|
||||
+++ b/file1.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
line1
|
||||
+added line
|
||||
line2
|
||||
line3
|
||||
"#}
|
||||
);
|
||||
|
||||
let file2_diff = extract_file_diff(multi_file_diff, "file2.txt").unwrap();
|
||||
assert_eq!(
|
||||
file2_diff,
|
||||
indoc! {r#"
|
||||
diff --git a/file2.txt b/file2.txt
|
||||
index 2345678..bcdefgh 100644
|
||||
--- a/file2.txt
|
||||
+++ b/file2.txt
|
||||
@@ -1,2 +1,2 @@
|
||||
-old line
|
||||
+new line
|
||||
unchanged
|
||||
"#}
|
||||
);
|
||||
|
||||
let result = extract_file_diff(multi_file_diff, "nonexistent.txt");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edits_for_diff() {
|
||||
let content = indoc! {"
|
||||
fn main() {
|
||||
let x = 1;
|
||||
let y = 2;
|
||||
println!(\"{} {}\", x, y);
|
||||
}
|
||||
"};
|
||||
|
||||
let diff = indoc! {"
|
||||
--- a/file.rs
|
||||
+++ b/file.rs
|
||||
@@ -1,5 +1,5 @@
|
||||
fn main() {
|
||||
- let x = 1;
|
||||
+ let x = 42;
|
||||
let y = 2;
|
||||
println!(\"{} {}\", x, y);
|
||||
}
|
||||
"};
|
||||
|
||||
let edits = edits_for_diff(content, diff).unwrap();
|
||||
assert_eq!(edits.len(), 1);
|
||||
|
||||
let (range, replacement) = &edits[0];
|
||||
// With sub-line diffing, the edit should start at "1" (the actual changed character)
|
||||
let expected_start = content.find("let x = 1;").unwrap() + "let x = ".len();
|
||||
assert_eq!(range.start, expected_start);
|
||||
// The deleted text is just "1"
|
||||
assert_eq!(range.end, expected_start + "1".len());
|
||||
// The replacement text
|
||||
assert_eq!(replacement, "42");
|
||||
|
||||
// Verify the cursor would be positioned at the column of "1"
|
||||
let line_start = content[..range.start]
|
||||
.rfind('\n')
|
||||
.map(|p| p + 1)
|
||||
.unwrap_or(0);
|
||||
let cursor_column = range.start - line_start;
|
||||
// " let x = " is 12 characters, so column 12
|
||||
assert_eq!(cursor_column, " let x = ".len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_diff_metadata() {
|
||||
let diff_with_metadata = indoc! {r#"
|
||||
diff --git a/file.txt b/file.txt
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
context line
|
||||
-removed line
|
||||
+added line
|
||||
more context
|
||||
"#};
|
||||
|
||||
let stripped = strip_diff_metadata(diff_with_metadata);
|
||||
|
||||
assert_eq!(
|
||||
stripped,
|
||||
indoc! {r#"
|
||||
--- a/file.txt
|
||||
+++ b/file.txt
|
||||
@@ -1,3 +1,4 @@
|
||||
context line
|
||||
-removed line
|
||||
+added line
|
||||
more context
|
||||
"#}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
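Besides the happy path exercised in `test_edits_for_diff` above, `edits_for_diff` deliberately returns no edits when the hunk's context matches more than one place in the content. A small illustrative sketch of that case (file name and text are made up):

```rust
use indoc::indoc;

// The hunk's old text ("x\ny\n") occurs twice in `content`, so applying it
// would be ambiguous and the function is expected to return no edits.
let content = "x\ny\nq\nx\ny\n";
let diff = indoc! {"
    --- a/file.txt
    +++ b/file.txt
    @@ -1,2 +1,2 @@
     x
    -y
    +Y
"};

let edits = edits_for_diff(content, diff).unwrap();
assert!(edits.is_empty());
```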
@@ -1,8 +1,10 @@
|
||||
use anthropic::{
|
||||
ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
|
||||
Response as AnthropicResponse, Role, non_streaming_completion,
|
||||
ANTHROPIC_API_URL, Event, Message, Request as AnthropicRequest, RequestContent,
|
||||
Response as AnthropicResponse, ResponseContent, Role, non_streaming_completion,
|
||||
stream_completion,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt as _;
|
||||
use http_client::HttpClient;
|
||||
use indoc::indoc;
|
||||
use reqwest_client::ReqwestClient;
|
||||
@@ -15,12 +17,12 @@ use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PlainLlmClient {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
api_key: String,
|
||||
pub http_client: Arc<dyn HttpClient>,
|
||||
pub api_key: String,
|
||||
}
|
||||
|
||||
impl PlainLlmClient {
|
||||
fn new() -> Result<Self> {
|
||||
pub fn new() -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
@@ -30,7 +32,7 @@ impl PlainLlmClient {
|
||||
})
|
||||
}
|
||||
|
||||
async fn generate(
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
@@ -63,6 +65,72 @@ impl PlainLlmClient {
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn generate_streaming<F>(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
mut on_progress: F,
|
||||
) -> Result<AnthropicResponse>
|
||||
where
|
||||
F: FnMut(usize, &str),
|
||||
{
|
||||
let request = AnthropicRequest {
|
||||
model: model.to_string(),
|
||||
max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
thinking: None,
|
||||
tool_choice: None,
|
||||
system: None,
|
||||
metadata: None,
|
||||
stop_sequences: Vec::new(),
|
||||
temperature: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
};
|
||||
|
||||
let mut stream = stream_completion(
|
||||
self.http_client.as_ref(),
|
||||
ANTHROPIC_API_URL,
|
||||
&self.api_key,
|
||||
request,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
let mut response: Option<AnthropicResponse> = None;
|
||||
let mut text_content = String::new();
|
||||
|
||||
while let Some(event_result) = stream.next().await {
|
||||
let event = event_result.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
match event {
|
||||
Event::MessageStart { message } => {
|
||||
response = Some(message);
|
||||
}
|
||||
Event::ContentBlockDelta { delta, .. } => {
|
||||
if let anthropic::ContentDelta::TextDelta { text } = delta {
|
||||
text_content.push_str(&text);
|
||||
on_progress(text_content.len(), &text_content);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mut response = response.ok_or_else(|| anyhow::anyhow!("No response received"))?;
|
||||
|
||||
if response.content.is_empty() && !text_content.is_empty() {
|
||||
response
|
||||
.content
|
||||
.push(ResponseContent::Text { text: text_content });
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BatchingLlmClient {
|
||||
@@ -408,6 +476,29 @@ impl AnthropicClient {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn generate_streaming<F>(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
on_progress: F,
|
||||
) -> Result<Option<AnthropicResponse>>
|
||||
where
|
||||
F: FnMut(usize, &str),
|
||||
{
|
||||
match self {
|
||||
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
|
||||
.generate_streaming(model, max_tokens, messages, on_progress)
|
||||
.await
|
||||
.map(Some),
|
||||
AnthropicClient::Batch(_) => {
|
||||
anyhow::bail!("Streaming not supported with batching client")
|
||||
}
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync_batches(&self) -> Result<()> {
|
||||
match self {
|
||||
AnthropicClient::Plain(_) => Ok(()),
|
||||
|
||||
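The new `generate_streaming` method above accumulates `ContentBlockDelta` text while invoking a progress callback. A rough usage sketch, assuming a `Vec<Message>` named `messages` built elsewhere; the model id and token budget are placeholders:

```rust
// `FnMut(usize, &str)` is the callback shape declared in this diff: total length
// of the streamed text so far, plus the accumulated text itself.
let client = PlainLlmClient::new()?; // requires ANTHROPIC_API_KEY to be set
let response = client
    .generate_streaming("some-model-id", 1024, messages, |len, _text| {
        eprintln!("streamed {len} bytes of text so far");
    })
    .await?;
println!("content blocks in final response: {}", response.content.len());
```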
@@ -1,20 +1,15 @@
-use anyhow::{Result, anyhow};
+use anyhow::Result;
 use std::mem;
 
 use crate::example::Example;
 
 pub async fn run_distill(example: &mut Example) -> Result<()> {
-    let [prediction]: [_; 1] =
-        mem::take(&mut example.predictions)
-            .try_into()
-            .map_err(|preds: Vec<_>| {
-                anyhow!(
-                    "Example has {} predictions, but it should have exactly one",
-                    preds.len()
-                )
-            })?;
+    let predictions = mem::take(&mut example.predictions)
+        .into_iter()
+        .map(|p| p.actual_patch)
+        .collect();
 
-    example.spec.expected_patch = prediction.actual_patch;
+    example.spec.expected_patches = predictions;
     example.prompt = None;
     example.predictions = Vec::new();
     example.score = Vec::new();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{PredictionProvider, PromptFormat, metrics::ClassificationMetrics};
|
||||
use crate::{PredictionProvider, PromptFormat};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::example_spec::ExampleSpec;
|
||||
@@ -87,7 +87,6 @@ pub struct ExamplePrediction {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleScore {
|
||||
pub delta_chr_f: f32,
|
||||
pub line_match: ClassificationMetrics,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
@@ -190,7 +189,11 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
.collect::<Vec<Example>>(),
|
||||
),
|
||||
"md" => {
|
||||
examples.push(parse_markdown_example(filename, &content).unwrap());
|
||||
let mut example = parse_markdown_example(&content).unwrap();
|
||||
if example.spec.name.is_empty() {
|
||||
example.spec.name = filename;
|
||||
}
|
||||
examples.push(example);
|
||||
}
|
||||
ext => {
|
||||
panic!("{} has invalid example extension `{ext}`", path.display())
|
||||
@@ -236,8 +239,8 @@ pub fn group_examples_by_repo(examples: &mut [Example]) -> Vec<Vec<&mut Example>
|
||||
examples_by_repo.into_values().collect()
|
||||
}
|
||||
|
||||
fn parse_markdown_example(name: String, input: &str) -> Result<Example> {
|
||||
let spec = ExampleSpec::from_markdown(name, input)?;
|
||||
fn parse_markdown_example(input: &str) -> Result<Example> {
|
||||
let spec = ExampleSpec::from_markdown(input)?;
|
||||
Ok(Example {
|
||||
spec,
|
||||
buffer: None,
|
||||
|
||||
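With the example name now coming from the markdown `#` heading, `read_examples` above falls back to the file name only when the heading is absent. A sketch of that fallback in isolation; the helper function and values here are illustrative, not part of the diff:

```rust
// Mirrors the fallback in read_examples: prefer the parsed name, else the filename.
fn resolve_example_name(parsed_name: &str, filename: &str) -> String {
    if parsed_name.is_empty() {
        filename.to_string()
    } else {
        parsed_name.to_string()
    }
}

assert_eq!(resolve_example_name("", "fallback-example"), "fallback-example");
assert_eq!(resolve_example_name("from-heading", "ignored"), "from-heading");
```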
@@ -30,7 +30,12 @@ pub async fn run_format_prompt(
|
||||
let prompt = TeacherPrompt::format_prompt(example);
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.spec.expected_patch.clone(), // TODO
|
||||
expected_output: example
|
||||
.spec
|
||||
.expected_patches
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_default(),
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
@@ -45,6 +50,11 @@ pub async fn run_format_prompt(
|
||||
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
|
||||
let events = ep_store
|
||||
.edit_history_for_project(&project, cx)
|
||||
.into_iter()
|
||||
.map(|e| e.event)
|
||||
.collect();
|
||||
anyhow::Ok(zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example
|
||||
@@ -53,7 +63,7 @@ pub async fn run_format_prompt(
|
||||
.context("context must be set")?
|
||||
.files
|
||||
.clone(),
|
||||
ep_store.edit_history_for_project(&project, cx),
|
||||
events,
|
||||
example.spec.cursor_path.clone(),
|
||||
example
|
||||
.buffer
|
||||
@@ -63,8 +73,15 @@ pub async fn run_format_prompt(
|
||||
))
|
||||
})??;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let expected_output =
|
||||
zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
|
||||
let expected_output = zeta2_output_for_patch(
|
||||
&input,
|
||||
&example
|
||||
.spec
|
||||
.expected_patches
|
||||
.first()
|
||||
.context("expected patches is empty")?
|
||||
.clone(),
|
||||
)?;
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output,
|
||||
@@ -81,6 +98,7 @@ impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
@@ -176,13 +194,15 @@ impl TeacherPrompt {
|
||||
result.push_str(Self::EDITABLE_REGION_START);
|
||||
|
||||
// TODO: control number of lines around cursor
|
||||
result.push_str(&example.spec.cursor_position);
|
||||
if !example.spec.cursor_position.ends_with('\n') {
|
||||
let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
|
||||
excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
|
||||
result.push_str(&excerpt);
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
|
||||
result.push_str("`````");
|
||||
result.push_str(Self::EDITABLE_REGION_END);
|
||||
result.push_str("\n`````");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
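The teacher-prompt change above stops embedding the raw `cursor_position` excerpt and instead splices `<|user_cursor|>` into the plain excerpt at the offset returned by `cursor_excerpt()`. The insertion step in isolation (excerpt text is illustrative):

```rust
// insert_str places the marker at the cursor's byte offset, as the updated
// format_prompt code does.
let (mut excerpt, offset) = ("let x = 42;".to_string(), "let x = ".len());
excerpt.insert_str(offset, "<|user_cursor|>");
assert_eq!(excerpt, "let x = <|user_cursor|>42;");
```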
110
crates/edit_prediction_cli/src/git.rs
Normal file
110
crates/edit_prediction_cli/src/git.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use futures::lock::{Mutex, OwnedMutexGuard};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::paths::REPOS_DIR;
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
pub fn parse_repo_url(url: &str) -> Result<(String, String)> {
|
||||
if url.contains('@') {
|
||||
let (_, path) = url.split_once(':').context("expected : in git url")?;
|
||||
let (owner, repo) = path.split_once('/').context("expected / in git url")?;
|
||||
Ok((owner.to_string(), repo.trim_end_matches(".git").to_string()))
|
||||
} else {
|
||||
let parsed = http_client::Url::parse(url)?;
|
||||
let mut segments = parsed.path_segments().context("empty http url")?;
|
||||
let owner = segments.next().context("expected owner")?;
|
||||
let repo = segments.next().context("expected repo")?;
|
||||
Ok((owner.to_string(), repo.trim_end_matches(".git").to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn repo_path_for_url(url: &str) -> Result<PathBuf> {
|
||||
let (owner, name) = parse_repo_url(url)?;
|
||||
Ok(REPOS_DIR.join(&owner).join(&name))
|
||||
}
|
||||
|
||||
pub async fn ensure_repo_cloned(repo_url: &str) -> Result<PathBuf> {
|
||||
let repo_path = repo_path_for_url(repo_url)?;
|
||||
let _lock = lock_repo(&repo_path).await;
|
||||
|
||||
if !repo_path.is_dir() {
|
||||
log::info!("Cloning {} into {:?}", repo_url, repo_path);
|
||||
std::fs::create_dir_all(&repo_path)?;
|
||||
run_git(&repo_path, &["init"]).await?;
|
||||
run_git(&repo_path, &["remote", "add", "origin", repo_url]).await?;
|
||||
}
|
||||
|
||||
// Always fetch to get latest commits
|
||||
run_git(&repo_path, &["fetch", "origin"]).await?;
|
||||
|
||||
// Check if we have a valid HEAD, if not checkout FETCH_HEAD
|
||||
let has_head = run_git(&repo_path, &["rev-parse", "HEAD"]).await.is_ok();
|
||||
if !has_head {
|
||||
// Use reset to set HEAD without needing a branch
|
||||
run_git(&repo_path, &["reset", "--hard", "FETCH_HEAD"]).await?;
|
||||
}
|
||||
|
||||
Ok(repo_path)
|
||||
}
|
||||
|
||||
pub async fn fetch_if_needed(repo_path: &Path, revision: &str) -> Result<String> {
|
||||
let resolved = run_git(
|
||||
repo_path,
|
||||
&["rev-parse", &format!("{}^{{commit}}", revision)],
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Ok(sha) = resolved {
|
||||
return Ok(sha);
|
||||
}
|
||||
|
||||
if run_git(repo_path, &["fetch", "--depth", "1", "origin", revision])
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(repo_path, &["fetch", "origin"]).await?;
|
||||
}
|
||||
|
||||
run_git(repo_path, &["rev-parse", "FETCH_HEAD"]).await
|
||||
}
|
||||
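The new `git` module centralizes repository handling. `parse_repo_url` accepts both SSH-style and HTTPS remotes; a quick illustrative check (URLs are placeholders):

```rust
assert_eq!(
    parse_repo_url("git@github.com:octocat/hello-world.git").unwrap(),
    ("octocat".to_string(), "hello-world".to_string())
);
assert_eq!(
    parse_repo_url("https://github.com/octocat/hello-world.git").unwrap(),
    ("octocat".to_string(), "hello-world".to_string())
);
```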
@@ -1,29 +1,19 @@
|
||||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
git,
|
||||
headless::EpAppState,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
paths::WORKTREES_DIR,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use futures::AsyncWriteExt as _;
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
|
||||
use project::Project;
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
use std::{fs, path::PathBuf, sync::Arc};
|
||||
|
||||
pub async fn run_load_project(
|
||||
example: &mut Example,
|
||||
@@ -86,37 +76,22 @@ async fn cursor_position(
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("No visible worktrees")
|
||||
})??;
|
||||
|
||||
let cursor_path = RelPath::new(&example.spec.cursor_path, PathStyle::Posix)
|
||||
.context("Failed to create RelPath")?
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
let cursor_path = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.find_project_path(&example.spec.cursor_path, cx)
|
||||
})?
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to find cursor path {}",
|
||||
example.spec.cursor_path.display()
|
||||
)
|
||||
})?;
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = example
|
||||
.spec
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.context("missing cursor marker")?;
|
||||
let mut cursor_excerpt = example.spec.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
|
||||
let (cursor_excerpt, cursor_offset_within_excerpt) = example.spec.cursor_excerpt()?;
|
||||
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
@@ -212,17 +187,17 @@ async fn setup_project(
|
||||
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_dir = git::repo_path_for_url(&example.spec.repository_url)?;
|
||||
let worktree_path = WORKTREES_DIR
|
||||
.join(repo_owner.as_ref())
|
||||
.join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
let repo_lock = git::lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
step_progress.set_substatus(format!("cloning {}", repo_name));
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
git::run_git(&repo_dir, &["init"]).await?;
|
||||
git::run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.spec.repository_url],
|
||||
)
|
||||
@@ -230,53 +205,26 @@ async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Resu
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"rev-parse",
|
||||
&format!("{}^{{commit}}", example.spec.revision),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
step_progress.set_substatus("fetching");
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &example.spec.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
revision
|
||||
};
|
||||
step_progress.set_substatus("fetching");
|
||||
let revision = git::fetch_if_needed(&repo_dir, &example.spec.revision).await?;
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
step_progress.set_substatus("preparing worktree");
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
git::run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
git::run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
git::run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
let branch_name = example.spec.filename();
|
||||
git::run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.spec.name, revision.as_str()],
|
||||
&["branch", "-f", &branch_name, revision.as_str()],
|
||||
)
|
||||
.await?;
|
||||
run_git(
|
||||
git::run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"worktree",
|
||||
"add",
|
||||
"-f",
|
||||
&worktree_path_string,
|
||||
&example.spec.name,
|
||||
],
|
||||
&["worktree", "add", "-f", &worktree_path_string, &branch_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -319,39 +267,3 @@ async fn apply_edit_history(
|
||||
) -> Result<OpenedBuffers> {
|
||||
edit_prediction::udiff::apply_diff(&example.spec.edit_history, project, cx).await
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
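The load-project changes above drop the manual worktree-id and `RelPath` plumbing in favor of `Project::find_project_path`. A condensed sketch of the resulting open-buffer pattern, assuming an async context with `project` and `cx` in scope as in this diff; the path literal is a placeholder:

```rust
use anyhow::Context as _;

let project_path = project
    .read_with(cx, |project, cx| {
        project.find_project_path("src/main.rs", cx)
    })?
    .context("failed to find path in any worktree")?;
let buffer = project
    .update(cx, |project, cx| project.open_buffer(project_path, cx))?
    .await?;
```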
@@ -2,6 +2,7 @@ mod anthropic_client;
|
||||
mod distill;
|
||||
mod example;
|
||||
mod format_prompt;
|
||||
mod git;
|
||||
mod headless;
|
||||
mod load_project;
|
||||
mod metrics;
|
||||
@@ -10,6 +11,7 @@ mod predict;
|
||||
mod progress;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
mod synthesize;
|
||||
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use edit_prediction::EditPredictionStore;
|
||||
@@ -28,6 +30,7 @@ use crate::predict::run_prediction;
|
||||
use crate::progress::Progress;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
use crate::score::run_scoring;
|
||||
use crate::synthesize::{SynthesizeConfig, run_synthesize};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "ep")]
|
||||
@@ -67,6 +70,8 @@ enum Command {
|
||||
Distill,
|
||||
/// Print aggregated scores
|
||||
Eval(PredictArgs),
|
||||
/// Generate eval examples by analyzing git commits from a repository
|
||||
Synthesize(SynthesizeArgs),
|
||||
/// Remove git repositories and worktrees
|
||||
Clean,
|
||||
}
|
||||
@@ -118,6 +123,9 @@ impl Display for Command {
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Synthesize(args) => {
|
||||
write!(f, "synthesize --repo={}", args.repo)
|
||||
}
|
||||
Command::Clean => write!(f, "clean"),
|
||||
}
|
||||
}
|
||||
@@ -143,7 +151,7 @@ struct PredictArgs {
|
||||
repetitions: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
|
||||
enum PredictionProvider {
|
||||
Sweep,
|
||||
Mercury,
|
||||
@@ -153,6 +161,29 @@ enum PredictionProvider {
|
||||
TeacherNonBatching,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct SynthesizeArgs {
|
||||
/// Repository URL (git@github.com:owner/repo or https://...)
|
||||
#[clap(long)]
|
||||
repo: String,
|
||||
|
||||
/// Number of examples to generate
|
||||
#[clap(long, default_value_t = 5)]
|
||||
count: usize,
|
||||
|
||||
/// Maximum commits to scan before giving up
|
||||
#[clap(long, default_value_t = 100)]
|
||||
max_commits: usize,
|
||||
|
||||
/// Only generate examples that require retrieved context to make a correct prediction
|
||||
#[clap(long)]
|
||||
require_context: bool,
|
||||
|
||||
/// Ignore state file and reprocess all commits
|
||||
#[clap(long)]
|
||||
fresh: bool,
|
||||
}
|
||||
|
||||
impl EpArgs {
|
||||
fn output_path(&self) -> Option<PathBuf> {
|
||||
if self.in_place {
|
||||
@@ -189,6 +220,26 @@ fn main() {
|
||||
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
|
||||
return;
|
||||
}
|
||||
Command::Synthesize(synth_args) => {
|
||||
let Some(output_dir) = args.output else {
|
||||
panic!("output dir is required");
|
||||
};
|
||||
let config = SynthesizeConfig {
|
||||
repo_url: synth_args.repo.clone(),
|
||||
count: synth_args.count,
|
||||
max_commits: synth_args.max_commits,
|
||||
output_dir,
|
||||
require_context: synth_args.require_context,
|
||||
fresh: synth_args.fresh,
|
||||
};
|
||||
smol::block_on(async {
|
||||
if let Err(e) = run_synthesize(config).await {
|
||||
eprintln!("Error: {:?}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -256,7 +307,7 @@ fn main() {
|
||||
run_scoring(example, &args, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Clean => {
|
||||
Command::Clean | Command::Synthesize(_) => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
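The new `Synthesize` subcommand above maps its clap arguments straight into a `SynthesizeConfig`. A sketch of constructing the config directly, e.g. from a test; field names are taken from this diff, the values are placeholders:

```rust
use std::path::PathBuf;

let config = SynthesizeConfig {
    repo_url: "git@github.com:octocat/hello-world.git".to_string(),
    count: 5,
    max_commits: 100,
    output_dir: PathBuf::from("./examples"),
    require_context: false,
    fresh: false,
};
smol::block_on(async {
    if let Err(e) = run_synthesize(config).await {
        eprintln!("Error: {:?}", e);
    }
});
```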
@@ -1,34 +1,17 @@
|
||||
use collections::{HashMap, HashSet};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use collections::HashMap;
|
||||
|
||||
type Counts = HashMap<String, usize>;
|
||||
type CountsDelta = HashMap<String, isize>;
|
||||
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationMetrics {
|
||||
pub true_positives: usize,
|
||||
pub false_positives: usize,
|
||||
pub false_negatives: usize,
|
||||
#[derive(Default, Debug, Clone)]
|
||||
struct ClassificationMetrics {
|
||||
true_positives: usize,
|
||||
false_positives: usize,
|
||||
false_negatives: usize,
|
||||
}
|
||||
|
||||
impl ClassificationMetrics {
|
||||
pub fn from_sets(
|
||||
expected: &HashSet<String>,
|
||||
actual: &HashSet<String>,
|
||||
) -> ClassificationMetrics {
|
||||
let true_positives = expected.intersection(actual).count();
|
||||
let false_positives = actual.difference(expected).count();
|
||||
let false_negatives = expected.difference(actual).count();
|
||||
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
|
||||
fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
@@ -56,27 +39,7 @@ impl ClassificationMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn aggregate<'a>(
|
||||
scores: impl Iterator<Item = &'a ClassificationMetrics>,
|
||||
) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
|
||||
for score in scores {
|
||||
true_positives += score.true_positives;
|
||||
false_positives += score.false_positives;
|
||||
false_negatives += score.false_negatives;
|
||||
}
|
||||
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn precision(&self) -> f64 {
|
||||
fn precision(&self) -> f64 {
|
||||
if self.true_positives + self.false_positives == 0 {
|
||||
0.0
|
||||
} else {
|
||||
@@ -84,42 +47,13 @@ impl ClassificationMetrics {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn recall(&self) -> f64 {
|
||||
fn recall(&self) -> f64 {
|
||||
if self.true_positives + self.false_negatives == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.true_positives as f64 / (self.true_positives + self.false_negatives) as f64
|
||||
}
|
||||
}
|
||||
|
||||
pub fn f1_score(&self) -> f64 {
|
||||
let recall = self.recall();
|
||||
let precision = self.precision();
|
||||
if precision + recall == 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
2.0 * precision * recall / (precision + recall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn line_match_score(
|
||||
expected_patch: &[DiffLine],
|
||||
actual_patch: &[DiffLine],
|
||||
) -> ClassificationMetrics {
|
||||
let expected_change_lines = expected_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
let actual_change_lines = actual_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
}
|
||||
|
||||
enum ChrfWhitespace {
|
||||
@@ -135,55 +69,26 @@ const CHR_F_WHITESPACE: ChrfWhitespace = ChrfWhitespace::Ignore;
|
||||
/// Computes a delta-chrF score that compares two sets of edits.
|
||||
///
|
||||
/// This metric works by:
|
||||
/// 1. Reconstructing original, golden (expected result), and actual texts from diffs
|
||||
/// 2. Computing n-gram count differences (deltas) between original→golden and original→actual
|
||||
/// 3. Comparing these deltas to measure how well actual edits match expected edits
|
||||
pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
|
||||
// Reconstruct texts from diffs
|
||||
let mut original_text = String::new(); // state of the text before any edits
|
||||
let mut golden_text = String::new(); // text after applying golden edits
|
||||
let mut actual_text = String::new(); // text after applying actual edits
|
||||
|
||||
for line in expected {
|
||||
match line {
|
||||
DiffLine::Context(s) => {
|
||||
original_text.push_str(s);
|
||||
golden_text.push_str(s);
|
||||
}
|
||||
DiffLine::Deletion(s) => {
|
||||
original_text.push_str(s);
|
||||
}
|
||||
DiffLine::Addition(s) => {
|
||||
golden_text.push_str(s);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
for line in actual {
|
||||
match line {
|
||||
DiffLine::Context(s) | DiffLine::Addition(s) => {
|
||||
actual_text.push_str(s);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Edge case
|
||||
if original_text == golden_text && golden_text == actual_text {
|
||||
/// 1. Computing n-gram count differences (deltas) between original→expected and original→actual
|
||||
/// 2. Comparing these deltas to measure how well actual edits match expected edits
|
||||
///
|
||||
/// Returns a score from 0.0 to 100.0, where 100.0 means the actual edits perfectly match
|
||||
/// the expected edits.
|
||||
pub fn delta_chr_f(original: &str, expected: &str, actual: &str) -> f64 {
|
||||
// Edge case: if all texts are identical, the edits match perfectly
|
||||
if original == expected && expected == actual {
|
||||
return 100.0;
|
||||
}
|
||||
|
||||
// Compute the metric
|
||||
let original_ngrams = chr_f_ngram_counts(&original_text);
|
||||
let golden_ngrams = chr_f_ngram_counts(&golden_text);
|
||||
let actual_ngrams = chr_f_ngram_counts(&actual_text);
|
||||
let original_ngrams = chr_f_ngram_counts(original);
|
||||
let expected_ngrams = chr_f_ngram_counts(expected);
|
||||
let actual_ngrams = chr_f_ngram_counts(actual);
|
||||
|
||||
let mut total_precision = 0.0;
|
||||
let mut total_recall = 0.0;
|
||||
|
||||
for order in 0..CHR_F_CHAR_ORDER {
|
||||
let expected_delta = compute_ngram_delta(&golden_ngrams[order], &original_ngrams[order]);
|
||||
let expected_delta = compute_ngram_delta(&expected_ngrams[order], &original_ngrams[order]);
|
||||
let actual_delta = compute_ngram_delta(&actual_ngrams[order], &original_ngrams[order]);
|
||||
|
||||
if expected_delta.is_empty() && actual_delta.is_empty() {
|
||||
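The delta-chrF doc comment above summarizes the algorithm at a high level. The following standalone sketch shows the core idea of comparing n-gram count deltas between the expected and actual edits; the helper names `ngram_counts` and `ngram_delta` and the single-order matching are illustrative assumptions, not the crate's actual implementation, which converts deltas to counts and averages precision and recall over several n-gram orders.

```rust
use std::collections::HashMap;

// Count character n-grams of a fixed order, ignoring whitespace
// (mirroring ChrfWhitespace::Ignore above). Illustration only.
fn ngram_counts(text: &str, n: usize) -> HashMap<String, usize> {
    let chars: Vec<char> = text.chars().filter(|c| !c.is_whitespace()).collect();
    let mut counts = HashMap::new();
    for window in chars.windows(n) {
        *counts.entry(window.iter().collect::<String>()).or_insert(0) += 1;
    }
    counts
}

// Signed change in each n-gram's frequency caused by an edit.
fn ngram_delta(
    after: &HashMap<String, usize>,
    before: &HashMap<String, usize>,
) -> HashMap<String, isize> {
    let mut delta: HashMap<String, isize> = HashMap::new();
    for (ngram, &count) in after {
        delta.insert(ngram.clone(), count as isize);
    }
    for (ngram, &count) in before {
        *delta.entry(ngram.clone()).or_insert(0) -= count as isize;
    }
    delta.retain(|_, d| *d != 0);
    delta
}

fn main() {
    let original = "let x = 42;";
    let expected = "let x = 100;";
    let actual = "let x = 99;";

    let n = 3;
    let expected_delta = ngram_delta(&ngram_counts(expected, n), &ngram_counts(original, n));
    let actual_delta = ngram_delta(&ngram_counts(actual, n), &ngram_counts(original, n));

    // n-grams whose count moved in the same direction in both deltas count as
    // "matched" edits; the real metric turns the deltas into counts and then
    // computes precision/recall per order.
    let matched = expected_delta
        .iter()
        .filter(|(ngram, d)| actual_delta.get(*ngram).is_some_and(|a| a.signum() == d.signum()))
        .count();
    println!("matched {} of {} expected n-gram changes", matched, expected_delta.len());
}
```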
@@ -255,7 +160,7 @@ fn ngram_delta_to_counts(delta: &CountsDelta) -> Counts {
|
||||
for (ngram, &delta) in delta {
|
||||
if delta > 0 {
|
||||
counts.insert(ngram.clone(), delta as usize);
|
||||
} else {
|
||||
} else if delta < 0 {
|
||||
counts.insert(format!("¬{ngram}"), delta.unsigned_abs());
|
||||
}
|
||||
}
|
||||
@@ -278,94 +183,68 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_perfect_match() {
|
||||
let diff = vec![
|
||||
DiffLine::Context("fn main() {"),
|
||||
DiffLine::Deletion(" println!(\"Hello\");"),
|
||||
DiffLine::Addition(" println!(\"Hello, World!\");"),
|
||||
DiffLine::Context("}"),
|
||||
];
|
||||
let original = "fn main() { println!(\"Hello\");}";
|
||||
let expected = "fn main() { println!(\"Hello, World!\");}";
|
||||
|
||||
let score = delta_chr_f(&diff, &diff);
|
||||
let score = delta_chr_f(original, expected, expected);
|
||||
assert!((score - 100.0).abs() < 1e-2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_wrong_edit() {
|
||||
// When the edit is wrong
|
||||
let expected = vec![
|
||||
DiffLine::Context("one "),
|
||||
DiffLine::Deletion("two "),
|
||||
DiffLine::Context("three"),
|
||||
];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Context("one "),
|
||||
DiffLine::Context("two "),
|
||||
DiffLine::Deletion("three"),
|
||||
DiffLine::Addition("four"),
|
||||
];
|
||||
let original = "one two three";
|
||||
let expected = "one three"; // deleted "two "
|
||||
let actual = "one two four"; // deleted "three", added "four"
|
||||
|
||||
// Then the score should be low
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score > 20.0 && score < 40.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_partial_match() {
|
||||
let expected = vec![
|
||||
DiffLine::Deletion("let x = 42;"),
|
||||
DiffLine::Addition("let x = 100;"),
|
||||
];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Deletion("let x = 42;"),
|
||||
DiffLine::Addition("let x = 99;"),
|
||||
];
|
||||
let original = "let x = 42;";
|
||||
let expected = "let x = 100;";
|
||||
let actual = "let x = 99;";
|
||||
|
||||
// We got the edit location right, but the replacement text is wrong.
|
||||
// Deleted ngrams will match, bringing the score somewhere in the middle.
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score > 40.0 && score < 60.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_missed_edit() {
|
||||
// When the prediction makes no changes
|
||||
let expected = vec![
|
||||
DiffLine::Context("prefix "),
|
||||
DiffLine::Deletion("old"),
|
||||
DiffLine::Addition("new"),
|
||||
DiffLine::Context(" suffix"),
|
||||
];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Context("prefix "),
|
||||
DiffLine::Context("old"),
|
||||
DiffLine::Context(" suffix"),
|
||||
];
|
||||
let original = "prefix old suffix";
|
||||
let expected = "prefix new suffix";
|
||||
let actual = "prefix old suffix"; // no change
|
||||
|
||||
// Then the score should be low (all expected changes are false negatives)
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score < 20.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_extra_edit() {
|
||||
// When adding unexpected content
|
||||
let expected = vec![DiffLine::Context("hello"), DiffLine::Context("world")];
|
||||
|
||||
let actual = vec![
|
||||
DiffLine::Context("hello"),
|
||||
DiffLine::Addition("extra"),
|
||||
DiffLine::Context("world"),
|
||||
];
|
||||
let original = "helloworld";
|
||||
let expected = "helloworld"; // no change expected
|
||||
let actual = "helloextraworld"; // added "extra"
|
||||
|
||||
// Then the score should be low (all actual changes are false positives)
|
||||
let score = delta_chr_f(&expected, &actual);
|
||||
let score = delta_chr_f(original, expected, actual);
|
||||
assert!(score < 20.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_no_changes() {
|
||||
let text = "unchanged text";
|
||||
let score = delta_chr_f(text, text, text);
|
||||
assert!((score - 100.0).abs() < 1e-2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,11 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LATEST_FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("latest_failed"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
pub static SYNTHESIZE_STATE_FILE: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| DATA_DIR.join("synthesize_state.json"));
|
||||
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
|
||||
|
||||
|
||||
@@ -28,12 +28,16 @@ pub async fn run_prediction(
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
if !example.predictions.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let provider = provider.context("provider is required")?;
|
||||
|
||||
if let Some(existing_prediction) = example.predictions.first() {
|
||||
if existing_prediction.provider == provider {
|
||||
return Ok(());
|
||||
} else {
|
||||
example.predictions.clear();
|
||||
}
|
||||
}
|
||||
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
if matches!(
|
||||
@@ -184,7 +188,9 @@ pub async fn run_prediction(
|
||||
let actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
prediction
|
||||
.edit_preview
|
||||
.as_unified_diff(prediction.snapshot.file(), &prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ pub enum Step {
|
||||
FormatPrompt,
|
||||
Predict,
|
||||
Score,
|
||||
Synthesize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -62,6 +63,7 @@ impl Step {
|
||||
Step::FormatPrompt => "Format",
|
||||
Step::Predict => "Predict",
|
||||
Step::Score => "Score",
|
||||
Step::Synthesize => "Synthesize",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +74,7 @@ impl Step {
|
||||
Step::FormatPrompt => "\x1b[34m",
|
||||
Step::Predict => "\x1b[32m",
|
||||
Step::Score => "\x1b[31m",
|
||||
Step::Synthesize => "\x1b[36m",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ use crate::{
|
||||
PredictArgs,
|
||||
example::{Example, ExampleScore},
|
||||
headless::EpAppState,
|
||||
metrics::{self, ClassificationMetrics},
|
||||
metrics,
|
||||
predict::run_prediction,
|
||||
progress::{Progress, Step},
|
||||
};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use anyhow::Context as _;
|
||||
use edit_prediction::udiff::apply_diff_to_string;
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -27,18 +28,32 @@ pub async fn run_scoring(
|
||||
|
||||
let _progress = Progress::global().start(Step::Score, &example.spec.name);
|
||||
|
||||
let expected_patch = parse_patch(&example.spec.expected_patch);
|
||||
let original_text = &example.buffer.as_ref().unwrap().content;
|
||||
let expected_texts: Vec<String> = example
|
||||
.spec
|
||||
.expected_patches
|
||||
.iter()
|
||||
.map(|patch| {
|
||||
apply_diff_to_string(original_text, patch)
|
||||
.with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
let mut scores = vec![];
|
||||
|
||||
for pred in &example.predictions {
|
||||
let actual_patch = parse_patch(&pred.actual_patch);
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
|
||||
|
||||
for prediction in &example.predictions {
|
||||
let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
|
||||
Ok(text) => text,
|
||||
Err(_) => {
|
||||
scores.push(ExampleScore { delta_chr_f: 0.0 });
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let best_delta_chr_f = expected_texts
|
||||
.iter()
|
||||
.map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
|
||||
.fold(0.0, f32::max);
|
||||
scores.push(ExampleScore {
|
||||
delta_chr_f,
|
||||
line_match,
|
||||
delta_chr_f: best_delta_chr_f,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -46,42 +61,25 @@ pub async fn run_scoring(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
patch.lines().map(DiffLine::parse).collect()
|
||||
}
|
||||
|
||||
pub fn print_report(examples: &[Example]) {
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
|
||||
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
|
||||
);
|
||||
eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
let mut all_line_match_scores = Vec::new();
|
||||
let mut all_delta_chr_f_scores = Vec::new();
|
||||
|
||||
for example in examples {
|
||||
for score in example.score.iter() {
|
||||
let line_match = &score.line_match;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
truncate_name(&example.spec.name, 30),
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0,
|
||||
"{:<50} {:>9.2}",
|
||||
truncate_name(&example.spec.name, 50),
|
||||
score.delta_chr_f
|
||||
);
|
||||
|
||||
all_line_match_scores.push(line_match.clone());
|
||||
all_delta_chr_f_scores.push(score.delta_chr_f);
|
||||
}
|
||||
}
|
||||
@@ -90,22 +88,11 @@ pub fn print_report(examples: &[Example]) {
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
if !all_line_match_scores.is_empty() {
|
||||
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
|
||||
if !all_delta_chr_f_scores.is_empty() {
|
||||
let avg_delta_chr_f: f32 =
|
||||
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
"TOTAL",
|
||||
total_line_match.true_positives,
|
||||
total_line_match.false_positives,
|
||||
total_line_match.false_negatives,
|
||||
total_line_match.precision() * 100.0,
|
||||
total_line_match.recall() * 100.0,
|
||||
total_line_match.f1_score() * 100.0,
|
||||
avg_delta_chr_f
|
||||
);
|
||||
eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
902
crates/edit_prediction_cli/src/synthesize.rs
Normal file
@@ -0,0 +1,902 @@
|
||||
use crate::{
|
||||
anthropic_client::PlainLlmClient,
|
||||
git::{ensure_repo_cloned, run_git},
|
||||
paths::{FAILED_EXAMPLES_DIR, LATEST_FAILED_EXAMPLES_DIR, SYNTHESIZE_STATE_FILE},
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anthropic::ResponseContent;
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::Local;
|
||||
use collections::{HashMap, HashSet};
|
||||
use edit_prediction::{
|
||||
example_spec::ExampleSpec,
|
||||
udiff::{apply_diff_to_string, edits_for_diff},
|
||||
};
|
||||
use indoc::indoc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SynthesizeConfig {
|
||||
pub repo_url: String,
|
||||
pub count: usize,
|
||||
pub max_commits: usize,
|
||||
pub output_dir: PathBuf,
|
||||
pub require_context: bool,
|
||||
pub fresh: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct SynthesizeState {
|
||||
repositories: HashMap<String, RepoState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
struct RepoState {
|
||||
processed_commits: HashSet<String>,
|
||||
examples_generated: usize,
|
||||
}
|
||||
|
||||
impl SynthesizeState {
|
||||
fn load() -> Self {
|
||||
if SYNTHESIZE_STATE_FILE.exists() {
|
||||
std::fs::read_to_string(&*SYNTHESIZE_STATE_FILE)
|
||||
.ok()
|
||||
.and_then(|s| serde_json::from_str(&s).ok())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn save(&self) -> Result<()> {
|
||||
let content = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(&*SYNTHESIZE_STATE_FILE, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_processed(&self, repo_url: &str, commit_sha: &str) -> bool {
|
||||
self.repositories
|
||||
.get(repo_url)
|
||||
.is_some_and(|repo| repo.processed_commits.contains(commit_sha))
|
||||
}
|
||||
|
||||
fn mark_processed(&mut self, repo_url: &str, commit_sha: &str, examples_count: usize) {
|
||||
let repo = self.repositories.entry(repo_url.to_string()).or_default();
|
||||
repo.processed_commits.insert(commit_sha.to_string());
|
||||
repo.examples_generated += examples_count;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CommitInfo {
|
||||
sha: String,
|
||||
parent_sha: String,
|
||||
message: String,
|
||||
diff: String,
|
||||
expanded_diff: String,
|
||||
}
|
||||
|
||||
/// Claude's response parsed into structured form
|
||||
#[derive(Debug)]
|
||||
struct ClaudeResponse {
|
||||
name: String,
|
||||
reasoning: String,
|
||||
edit_history_hunks: Vec<String>,
|
||||
expected_patch_hunks: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn run_synthesize(config: SynthesizeConfig) -> Result<()> {
|
||||
let mut state = if config.fresh {
|
||||
SynthesizeState::default()
|
||||
} else {
|
||||
SynthesizeState::load()
|
||||
};
|
||||
|
||||
std::fs::create_dir_all(&config.output_dir)?;
|
||||
std::fs::create_dir_all(&*FAILED_EXAMPLES_DIR)?;
|
||||
|
||||
// Create "latest_failed" symlink pointing to this run's failed directory
|
||||
if LATEST_FAILED_EXAMPLES_DIR.is_symlink() {
|
||||
std::fs::remove_file(&*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&*FAILED_EXAMPLES_DIR, &*LATEST_FAILED_EXAMPLES_DIR)?;
|
||||
|
||||
let progress = Progress::global();
|
||||
progress.set_total_examples(config.count);
|
||||
|
||||
let clone_progress = progress.start(Step::Synthesize, "clone");
|
||||
let repo_path = ensure_repo_cloned(&config.repo_url).await?;
|
||||
drop(clone_progress);
|
||||
|
||||
let client = PlainLlmClient::new()?;
|
||||
let mut examples_generated = 0;
|
||||
let mut commits_skipped = 0;
|
||||
let batch_size = config.max_commits;
|
||||
|
||||
'outer: loop {
|
||||
let list_progress = progress.start(Step::Synthesize, "list-commits");
|
||||
let commits = list_commits(&repo_path, batch_size, commits_skipped).await?;
|
||||
drop(list_progress);
|
||||
|
||||
if commits.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
commits_skipped += commits.len();
|
||||
|
||||
for commit in commits {
|
||||
if examples_generated >= config.count {
|
||||
break 'outer;
|
||||
}
|
||||
|
||||
if !config.fresh && state.is_processed(&config.repo_url, &commit.sha) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if should_skip_commit(&commit) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let commit_label = format!(
|
||||
"{} {}",
|
||||
&commit.sha[..8],
|
||||
truncate_message(&commit.message, 40)
|
||||
);
|
||||
let step_progress = Arc::new(progress.start(Step::Synthesize, &commit_label));
|
||||
|
||||
// Single Claude call to identify and copy hunks
|
||||
step_progress.set_substatus("analyzing...");
|
||||
let claude_response =
|
||||
match analyze_commit(&client, &config, &commit, step_progress.clone()).await {
|
||||
Ok(Some(response)) => response,
|
||||
Ok(None) => {
|
||||
step_progress.set_info("no pattern", InfoStyle::Normal);
|
||||
state.mark_processed(&config.repo_url, &commit.sha, 0);
|
||||
state.save()?;
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
step_progress.set_info(format!("error: {:?}", e), InfoStyle::Warning);
|
||||
state.mark_processed(&config.repo_url, &commit.sha, 0);
|
||||
state.save()?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Validate and build the example
|
||||
step_progress.set_substatus("validating...");
|
||||
match build_example(&config, &commit, &repo_path, &claude_response).await {
|
||||
Ok(spec) => {
|
||||
let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S");
|
||||
let filename = format!("{}.md", timestamp);
|
||||
let path = config.output_dir.join(&filename);
|
||||
std::fs::write(&path, spec.to_markdown())?;
|
||||
examples_generated += 1;
|
||||
step_progress.set_info(filename, InfoStyle::Normal);
|
||||
}
|
||||
Err(rejection_reason) => {
|
||||
log::debug!("Example rejected: {}", rejection_reason);
|
||||
let timestamp = Local::now().format("%Y-%m-%d--%H-%M-%S%.3f");
|
||||
let filename = format!("{}.md", timestamp);
|
||||
let path = FAILED_EXAMPLES_DIR.join(&filename);
|
||||
let content = format_rejected_example(&claude_response, &rejection_reason);
|
||||
if let Err(e) = std::fs::write(&path, content) {
|
||||
log::warn!("Failed to write rejected example: {:?}", e);
|
||||
}
|
||||
step_progress.set_info(format!("rejected: {}", filename), InfoStyle::Warning);
|
||||
}
|
||||
}
|
||||
|
||||
state.mark_processed(&config.repo_url, &commit.sha, 1);
|
||||
state.save()?;
|
||||
}
|
||||
}
|
||||
|
||||
progress.finalize();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn truncate_message(msg: &str, max_len: usize) -> String {
|
||||
let first_line = msg.lines().next().unwrap_or("");
|
||||
if first_line.len() <= max_len {
|
||||
first_line.to_string()
|
||||
} else {
|
||||
format!("{}...", &first_line[..max_len - 3])
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip_commit(commit: &CommitInfo) -> bool {
|
||||
let lines_changed = commit
|
||||
.diff
|
||||
.lines()
|
||||
.filter(|l| l.starts_with('+') || l.starts_with('-'))
|
||||
.count();
|
||||
lines_changed < 10
|
||||
|| lines_changed > 1000
|
||||
|| is_non_code_commit(commit)
|
||||
|| is_rename_commit(commit)
|
||||
}
|
||||
|
||||
fn is_non_code_commit(commit: &CommitInfo) -> bool {
|
||||
let non_code_extensions = [
|
||||
".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".lock", ".svg", ".png", ".jpg", ".gif",
|
||||
".ico", ".woff", ".ttf", ".eot",
|
||||
];
|
||||
|
||||
let diff_files: Vec<&str> = commit
|
||||
.diff
|
||||
.lines()
|
||||
.filter(|l| l.starts_with("+++ b/") || l.starts_with("--- a/"))
|
||||
.filter_map(|l| {
|
||||
l.strip_prefix("+++ b/")
|
||||
.or_else(|| l.strip_prefix("--- a/"))
|
||||
})
|
||||
.collect();
|
||||
|
||||
if diff_files.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
diff_files
|
||||
.iter()
|
||||
.all(|f| non_code_extensions.iter().any(|ext| f.ends_with(ext)))
|
||||
}
|
||||
|
||||
fn is_rename_commit(commit: &CommitInfo) -> bool {
|
||||
commit.diff.contains("similarity index")
|
||||
|| commit.diff.contains("rename from")
|
||||
|| commit.diff.contains("rename to")
|
||||
}
|
||||
|
||||
async fn list_commits(
|
||||
repo_path: &Path,
|
||||
max_commits: usize,
|
||||
skip: usize,
|
||||
) -> Result<Vec<CommitInfo>> {
|
||||
let output = run_git(
|
||||
repo_path,
|
||||
&[
|
||||
"log",
|
||||
"--no-merges",
|
||||
&format!("--skip={}", skip),
|
||||
&format!("-{}", max_commits),
|
||||
"--format=%H|%P|%s",
|
||||
],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut commits = Vec::new();
|
||||
for line in output.lines() {
|
||||
let parts: Vec<&str> = line.splitn(3, '|').collect();
|
||||
if parts.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
let sha = parts[0].to_string();
|
||||
let parent_sha = parts[1].split_whitespace().next().unwrap_or("").to_string();
|
||||
if parent_sha.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get standard diff (for skip checks)
|
||||
let diff = run_git(repo_path, &["show", "--format=", &sha])
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Get expanded diff with 30 lines of context
|
||||
let expanded_diff = run_git(repo_path, &["show", "-U30", "--format=", &sha])
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
commits.push(CommitInfo {
|
||||
sha,
|
||||
parent_sha,
|
||||
message: parts[2].to_string(),
|
||||
diff,
|
||||
expanded_diff,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(commits)
|
||||
}
|
||||
|
||||
fn build_prompt(config: &SynthesizeConfig, commit: &CommitInfo) -> String {
|
||||
let context_guidance = if config.require_context {
|
||||
"IMPORTANT: Only identify patterns that REQUIRE reading context from other files to make the prediction. \
|
||||
Single-file patterns (where the edit history and expected patch are in the same file) are NOT acceptable \
|
||||
unless the pattern clearly requires understanding code from other files."
|
||||
} else {
|
||||
"Both single-file and multi-file patterns are acceptable."
|
||||
};
|
||||
|
||||
format!(
|
||||
indoc! {r#"
|
||||
You are analyzing a git commit to construct a realistic edit prediction example.
|
||||
|
||||
Your goal is to tell the story of a programmer's editing session: what sequence of changes did they make, and what change logically comes next? We use these examples to train a model to predict edits, so the quality of the EDIT HISTORY is what matters most.
|
||||
|
||||
An edit prediction example consists of:
|
||||
1. **Edit History**: 3-6 hunks showing what the programmer did BEFORE making the expected patch. This is the most important part - it must tell a coherent story of the changes leading up to the prediction.
|
||||
2. **Expected Patch**: One small hunk that logically follows from the edit history.
|
||||
|
||||
{context_guidance}
|
||||
|
||||
## What Makes a Good Example
|
||||
|
||||
The edit history should read like a story: "First the programmer changed X, then Y, then Z, and now they need to change W."
|
||||
|
||||
GOOD examples (rich sequences with 3+ steps):
|
||||
- Removing a parameter: docstring update → constructor change → field removal → (predict) usage site update
|
||||
- Adding a feature: type definition → first usage → second usage → (predict) third usage
|
||||
- Bug fix pattern: fix in file A → fix in file B → fix in file C → (predict) fix in file D
|
||||
|
||||
BAD examples (respond NO_PATTERN):
|
||||
- Commits where all changes are independent (no narrative thread)
|
||||
- Simple find-and-replace (renaming, version bumps)
|
||||
- Documentation-only or config-only changes
|
||||
- Changes where you can only find 1-2 hunks for the edit history
|
||||
|
||||
## Commit Information
|
||||
|
||||
Repository: {repo_url}
|
||||
Commit: {sha}
|
||||
Message: {message}
|
||||
|
||||
## Diff (30 lines context)
|
||||
|
||||
```diff
|
||||
{expanded_diff}
|
||||
```
|
||||
|
||||
## Your Task
|
||||
|
||||
First, THINK through whether this commit can support a good example:
|
||||
|
||||
1. What is the high-level pattern in this commit?
|
||||
2. Can you identify at least 4 related hunks (3 for edit history + 1 for expected patch)?
|
||||
3. What would be the narrative? (First... then... then... finally predict...)
|
||||
4. Which specific hunk should be the expected patch (the "punchline")?
|
||||
|
||||
If you cannot construct a coherent 3+ hunk story, respond with just:
|
||||
NO_PATTERN: <brief reason>
|
||||
|
||||
If you CAN construct a good example, respond in this format:
|
||||
|
||||
ANALYSIS:
|
||||
Pattern: <one sentence describing the pattern>
|
||||
Steps:
|
||||
1. <file:line-range> - <what this hunk does>
|
||||
2. <file:line-range> - <what this hunk does>
|
||||
3. <file:line-range> - <what this hunk does>
|
||||
4. [EXPECTED PATCH] <file:line-range> - <what this hunk does>
|
||||
|
||||
NAME: <short description, like a commit message, under 60 chars>
|
||||
|
||||
EDIT_HISTORY:
|
||||
|
||||
Hunk 1:
|
||||
```diff
|
||||
--- a/src/models/user.py
|
||||
+++ b/src/models/user.py
|
||||
@@ -15,7 +15,6 @@ class User:
|
||||
"""A user in the system.
|
||||
|
||||
Attributes:
|
||||
- email: The user's email address.
|
||||
name: The user's display name.
|
||||
"""
|
||||
```
|
||||
|
||||
Hunk 2:
|
||||
```diff
|
||||
--- a/src/models/user.py
|
||||
+++ b/src/models/user.py
|
||||
@@ -25,10 +24,9 @@ class User:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
- email: str,
|
||||
created_at: datetime,
|
||||
):
|
||||
self.name = name
|
||||
- self.email = email
|
||||
self.created_at = created_at
|
||||
```
|
||||
|
||||
Hunk 3:
|
||||
```diff
|
||||
--- a/src/api/handlers.py
|
||||
+++ b/src/api/handlers.py
|
||||
@@ -42,7 +42,6 @@ def create_user(request):
|
||||
data = request.json()
|
||||
user = User(
|
||||
name=data["name"],
|
||||
- email=data["email"],
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
return user.save()
|
||||
```
|
||||
|
||||
EXPECTED_PATCH:
|
||||
```diff
|
||||
--- a/src/api/handlers.py
|
||||
+++ b/src/api/handlers.py
|
||||
@@ -58,7 +57,6 @@ def update_user(request, user_id):
|
||||
user = User.get(user_id)
|
||||
user.name = data.get("name", user.name)
|
||||
- user.email = data.get("email", user.email)
|
||||
user.save()
|
||||
return user
|
||||
```
|
||||
|
||||
## Requirements for the diffs
|
||||
|
||||
Edit history:
|
||||
- MUST have 3-6 hunks (if you cannot find 3+, respond NO_PATTERN instead)
|
||||
- Each hunk needs file headers (--- a/path and +++ b/path)
|
||||
- Hunks must be valid unified diffs that apply to the parent commit
|
||||
- Order hunks as a programmer would naturally make the changes
|
||||
|
||||
Expected patch:
|
||||
- Must be a SINGLE hunk from a SINGLE file
|
||||
- Must be SMALL: 1-15 changed lines (not counting context)
|
||||
- Must be clearly predictable from the edit history narrative
|
||||
"#},
|
||||
context_guidance = context_guidance,
|
||||
repo_url = config.repo_url,
|
||||
sha = commit.sha,
|
||||
message = commit.message,
|
||||
expanded_diff = commit.expanded_diff,
|
||||
)
|
||||
}
|
||||
|
||||
async fn analyze_commit(
|
||||
client: &PlainLlmClient,
|
||||
config: &SynthesizeConfig,
|
||||
commit: &CommitInfo,
|
||||
step_progress: Arc<StepProgress>,
|
||||
) -> Result<Option<ClaudeResponse>> {
|
||||
use anthropic::{Message, RequestContent, Role};
|
||||
|
||||
let prompt = build_prompt(config, commit);
|
||||
let messages = vec![Message {
|
||||
role: Role::User,
|
||||
content: vec![RequestContent::Text {
|
||||
text: prompt,
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let response = client
|
||||
.generate_streaming("claude-sonnet-4-5", 8192, messages, |chars, _text| {
|
||||
step_progress.set_substatus(format!("analyzing: {:.1}K", chars as f64 / 1000.0));
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Extract text content from response
|
||||
let response_text: String = response
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|block| {
|
||||
if let ResponseContent::Text { text } = block {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
parse_claude_response(&response_text)
|
||||
}
|
||||
|
||||
fn parse_claude_response(response: &str) -> Result<Option<ClaudeResponse>> {
|
||||
// Check for NO_PATTERN
|
||||
if response.contains("NO_PATTERN:") {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Parse NAME
|
||||
let name = response
|
||||
.lines()
|
||||
.find(|l| l.starts_with("NAME:"))
|
||||
.map(|l| l.strip_prefix("NAME:").unwrap_or("").trim().to_string())
|
||||
.unwrap_or_else(|| "unnamed example".to_string());
|
||||
|
||||
// Parse ANALYSIS section (Claude's planning) - this is the primary reasoning
|
||||
let reasoning = extract_section(
|
||||
response,
|
||||
"ANALYSIS:",
|
||||
&["NAME:", "REASONING:", "EDIT_HISTORY:", "EXPECTED_PATCH:"],
|
||||
)
|
||||
.unwrap_or_default();
|
||||
|
||||
// Parse EDIT_HISTORY diff block
|
||||
let edit_history_hunks = extract_diff_block(response, "EDIT_HISTORY:")?;
|
||||
|
||||
// Parse EXPECTED_PATCH diff block
|
||||
let expected_patch_hunks = extract_diff_block(response, "EXPECTED_PATCH:")?;
|
||||
|
||||
if edit_history_hunks.is_empty() {
|
||||
anyhow::bail!("No edit history hunks found in response");
|
||||
}
|
||||
if expected_patch_hunks.is_empty() {
|
||||
anyhow::bail!("No expected patch hunks found in response");
|
||||
}
|
||||
|
||||
Ok(Some(ClaudeResponse {
|
||||
name,
|
||||
reasoning,
|
||||
edit_history_hunks,
|
||||
expected_patch_hunks,
|
||||
}))
|
||||
}
|
||||
|
||||
fn extract_section(text: &str, start_marker: &str, end_markers: &[&str]) -> Option<String> {
|
||||
let start_idx = text.find(start_marker)?;
|
||||
let content_start = start_idx + start_marker.len();
|
||||
|
||||
let end_idx = end_markers
|
||||
.iter()
|
||||
.filter_map(|marker| text[content_start..].find(marker))
|
||||
.min()
|
||||
.map(|idx| content_start + idx)
|
||||
.unwrap_or(text.len());
|
||||
|
||||
Some(text[content_start..end_idx].trim().to_string())
|
||||
}
|
||||
|
||||
fn extract_diff_block(text: &str, section_marker: &str) -> Result<Vec<String>> {
|
||||
let section_start = text
|
||||
.find(section_marker)
|
||||
.context(format!("Section {} not found", section_marker))?;
|
||||
|
||||
let after_marker = &text[section_start + section_marker.len()..];
|
||||
|
||||
// Find where the next major section starts (to bound our search)
|
||||
let section_end = ["EXPECTED_PATCH:", "## "]
|
||||
.iter()
|
||||
.filter(|&&m| m != section_marker)
|
||||
.filter_map(|marker| after_marker.find(marker))
|
||||
.min()
|
||||
.unwrap_or(after_marker.len());
|
||||
|
||||
let section_content = &after_marker[..section_end];
|
||||
|
||||
// Collect all ```diff blocks in this section
|
||||
let mut hunks = Vec::new();
|
||||
let mut search_start = 0;
|
||||
|
||||
while let Some(diff_start) = section_content[search_start..].find("```diff") {
|
||||
let abs_diff_start = search_start + diff_start;
|
||||
let block_content_start = section_content[abs_diff_start..]
|
||||
.find('\n')
|
||||
.map(|i| abs_diff_start + i + 1)
|
||||
.unwrap_or(abs_diff_start);
|
||||
|
||||
if let Some(block_end_rel) = section_content[block_content_start..].find("```") {
|
||||
let block_end = block_content_start + block_end_rel;
|
||||
let diff_content = section_content[block_content_start..block_end].trim();
|
||||
|
||||
// Split this block into hunks (in case multiple hunks in one block)
|
||||
hunks.extend(split_into_hunks(diff_content));
|
||||
|
||||
search_start = block_end + 3;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if hunks.is_empty() {
|
||||
anyhow::bail!("No diff blocks found in section {}", section_marker);
|
||||
}
|
||||
|
||||
Ok(hunks)
|
||||
}
|
||||
|
||||
/// Split a diff block into individual hunks, preserving file headers
|
||||
fn split_into_hunks(diff: &str) -> Vec<String> {
|
||||
let mut hunks = Vec::new();
|
||||
let mut current_file_header: Option<String> = None;
|
||||
let mut current_hunk: Vec<String> = Vec::new();
|
||||
let mut in_hunk = false;
|
||||
|
||||
for line in diff.lines() {
|
||||
if line.starts_with("--- a/") || line.starts_with("--- /") {
|
||||
// Start of file header - flush previous hunk
|
||||
if in_hunk && !current_hunk.is_empty() {
|
||||
let mut hunk_text = String::new();
|
||||
if let Some(ref header) = current_file_header {
|
||||
hunk_text.push_str(header);
|
||||
hunk_text.push('\n');
|
||||
}
|
||||
hunk_text.push_str(&current_hunk.join("\n"));
|
||||
hunks.push(hunk_text);
|
||||
current_hunk.clear();
|
||||
}
|
||||
current_file_header = Some(line.to_string());
|
||||
in_hunk = false;
|
||||
} else if line.starts_with("+++ b/") || line.starts_with("+++ /") {
|
||||
if let Some(ref mut header) = current_file_header {
|
||||
header.push('\n');
|
||||
header.push_str(line);
|
||||
}
|
||||
} else if line.starts_with("@@ ") {
|
||||
// New hunk - flush previous
|
||||
if in_hunk && !current_hunk.is_empty() {
|
||||
let mut hunk_text = String::new();
|
||||
if let Some(ref header) = current_file_header {
|
||||
hunk_text.push_str(header);
|
||||
hunk_text.push('\n');
|
||||
}
|
||||
hunk_text.push_str(&current_hunk.join("\n"));
|
||||
hunks.push(hunk_text);
|
||||
current_hunk.clear();
|
||||
}
|
||||
current_hunk.push(line.to_string());
|
||||
in_hunk = true;
|
||||
} else if in_hunk {
|
||||
current_hunk.push(line.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Flush final hunk
|
||||
if !current_hunk.is_empty() {
|
||||
let mut hunk_text = String::new();
|
||||
if let Some(ref header) = current_file_header {
|
||||
hunk_text.push_str(header);
|
||||
hunk_text.push('\n');
|
||||
}
|
||||
hunk_text.push_str(&current_hunk.join("\n"));
|
||||
hunks.push(hunk_text);
|
||||
}
|
||||
|
||||
hunks
|
||||
}
|
||||
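A hypothetical test sketch of `split_into_hunks` (not part of this diff) illustrating the behavior the function implements: two hunks come back, and each one carries its own copy of the `--- a/` / `+++ b/` file header so it can later be applied independently.

```rust
#[cfg(test)]
mod split_into_hunks_sketch {
    use super::split_into_hunks;

    #[test]
    fn splits_two_hunks_and_repeats_the_file_header() {
        let diff = [
            "--- a/src/lib.rs",
            "+++ b/src/lib.rs",
            "@@ -1,3 +1,3 @@",
            "-old line",
            "+new line",
            " context",
            "@@ -10,2 +10,2 @@",
            "-another old",
            "+another new",
        ]
        .join("\n");

        let hunks = split_into_hunks(&diff);
        assert_eq!(hunks.len(), 2);
        // Each returned hunk keeps the file header so it can be applied on
        // its own later (e.g. by apply_diff_to_string).
        assert!(hunks[0].starts_with("--- a/src/lib.rs\n+++ b/src/lib.rs\n@@ -1,3 +1,3 @@"));
        assert!(hunks[1].starts_with("--- a/src/lib.rs\n+++ b/src/lib.rs\n@@ -10,2 +10,2 @@"));
    }
}
```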
|
||||
/// Validate Claude's output by applying diffs and build the ExampleSpec
|
||||
async fn build_example(
|
||||
config: &SynthesizeConfig,
|
||||
commit: &CommitInfo,
|
||||
repo_path: &Path,
|
||||
response: &ClaudeResponse,
|
||||
) -> Result<ExampleSpec, String> {
|
||||
// Validate expected patch hunks
|
||||
if response.expected_patch_hunks.len() != 1 {
|
||||
return Err(format!(
|
||||
"Expected exactly 1 expected patch hunk, got {}",
|
||||
response.expected_patch_hunks.len()
|
||||
));
|
||||
}
|
||||
|
||||
// Parse the expected patch to determine cursor file
|
||||
let expected_patch = &response.expected_patch_hunks[0];
|
||||
let cursor_file = extract_file_from_hunk(expected_patch)
|
||||
.ok_or_else(|| "Could not determine file from expected patch".to_string())?;
|
||||
|
||||
// Get the file content before the commit
|
||||
let before_content = run_git(
|
||||
repo_path,
|
||||
&["show", &format!("{}^:{}", commit.sha, cursor_file)],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to get file content for {}: {}", cursor_file, e))?;
|
||||
|
||||
// Build edit history diff from Claude's hunks
|
||||
let edit_history = response.edit_history_hunks.join("\n");
|
||||
|
||||
// Apply edit history to get intermediate state (validates edit history)
|
||||
let intermediate_state =
|
||||
apply_edit_history_to_content(&before_content, &edit_history, &cursor_file)?;
|
||||
|
||||
// Validate expected patch applies to intermediate state
|
||||
let expected_patch_with_header = ensure_diff_header(expected_patch, &cursor_file);
|
||||
apply_diff_to_string(&intermediate_state, &expected_patch_with_header)
|
||||
.map_err(|e| format!("Expected patch failed to apply: {}", e))?;
|
||||
|
||||
// Find where the expected patch edits would apply in the intermediate state
|
||||
let edits = edits_for_diff(&intermediate_state, &expected_patch_with_header)
|
||||
.map_err(|e| format!("Failed to parse expected patch: {}", e))?;
|
||||
if edits.is_empty() {
|
||||
return Err(
|
||||
"Could not locate expected patch in file (context not found or ambiguous)".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
// Use the start of the first edit for cursor positioning
|
||||
let cursor_byte_offset = edits[0].0.start;
|
||||
|
||||
// Extract excerpt around the edit location
|
||||
let (excerpt, cursor_offset) = extract_cursor_excerpt(&intermediate_state, cursor_byte_offset)?;
|
||||
|
||||
// Build the ExampleSpec and use set_cursor_excerpt to format with comment marker
|
||||
let comment_prefix = line_comment_prefix(&cursor_file);
|
||||
let reasoning_with_source = format!(
|
||||
"Source commit: {} ({})\n\n{}",
|
||||
commit.sha,
|
||||
truncate_message(&commit.message, 60),
|
||||
response.reasoning
|
||||
);
|
||||
let mut spec = ExampleSpec {
|
||||
name: response.name.clone(),
|
||||
repository_url: config.repo_url.clone(),
|
||||
revision: commit.parent_sha.clone(),
|
||||
tags: Vec::new(),
|
||||
reasoning: Some(reasoning_with_source),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: Arc::from(Path::new(&cursor_file)),
|
||||
cursor_position: String::new(),
|
||||
edit_history,
|
||||
expected_patches: vec![expected_patch_with_header],
|
||||
};
|
||||
spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);
|
||||
|
||||
Ok(spec)
|
||||
}
|
||||
|
||||
/// Extract file path from a hunk (looks for --- a/path or +++ b/path)
|
||||
fn extract_file_from_hunk(hunk: &str) -> Option<String> {
|
||||
for line in hunk.lines() {
|
||||
if let Some(path) = line.strip_prefix("+++ b/") {
|
||||
return Some(path.to_string());
|
||||
}
|
||||
if let Some(path) = line.strip_prefix("--- a/") {
|
||||
return Some(path.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Ensure a hunk has proper file headers
|
||||
fn ensure_diff_header(hunk: &str, file_path: &str) -> String {
|
||||
if hunk.contains("--- a/") || hunk.contains("+++ b/") {
|
||||
return hunk.to_string();
|
||||
}
|
||||
format!("--- a/{}\n+++ b/{}\n{}", file_path, file_path, hunk)
|
||||
}
|
||||
|
||||
/// Apply edit history to file content, only if hunks affect this file
|
||||
fn apply_edit_history_to_content(
|
||||
content: &str,
|
||||
edit_history: &str,
|
||||
cursor_file: &str,
|
||||
) -> Result<String, String> {
|
||||
// Extract just the hunks for this file from the edit history
|
||||
let file_diff = extract_file_diff_from_combined(edit_history, cursor_file);
|
||||
|
||||
if file_diff.is_empty() {
|
||||
return Ok(content.to_string());
|
||||
}
|
||||
|
||||
apply_diff_to_string(content, &file_diff)
|
||||
.map_err(|e| format!("Failed to apply edit history: {}", e))
|
||||
}
|
||||
|
||||
/// Extract hunks for a specific file from a combined diff
|
||||
fn extract_file_diff_from_combined(combined_diff: &str, target_file: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut in_target_file = false;
|
||||
let mut found_header = false;
|
||||
|
||||
for line in combined_diff.lines() {
|
||||
if line.starts_with("--- a/") {
|
||||
let file = line.strip_prefix("--- a/").unwrap_or("");
|
||||
in_target_file = file == target_file;
|
||||
if in_target_file {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
found_header = false;
|
||||
}
|
||||
} else if line.starts_with("+++ b/") && in_target_file {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
found_header = true;
|
||||
} else if in_target_file && found_header {
|
||||
if line.starts_with("--- a/") {
|
||||
break;
|
||||
}
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Extract a cursor position excerpt from content around a byte offset.
|
||||
/// Returns the excerpt and the cursor offset within the excerpt.
|
||||
fn extract_cursor_excerpt(
|
||||
content: &str,
|
||||
cursor_byte_offset: usize,
|
||||
) -> Result<(String, usize), String> {
|
||||
// Find the line containing the cursor
|
||||
let line_start = content[..cursor_byte_offset]
|
||||
.rfind('\n')
|
||||
.map(|pos| pos + 1)
|
||||
.unwrap_or(0);
|
||||
let line_end = content[cursor_byte_offset..]
|
||||
.find('\n')
|
||||
.map(|pos| cursor_byte_offset + pos)
|
||||
.unwrap_or(content.len());
|
||||
|
||||
// Get context lines before
|
||||
let lines_before: Vec<&str> = content[..line_start].lines().collect();
|
||||
let context_before: Vec<&str> = lines_before.iter().rev().take(3).rev().cloned().collect();
|
||||
|
||||
// Get context lines after
|
||||
let after_line_end = if line_end < content.len() {
|
||||
line_end + 1
|
||||
} else {
|
||||
line_end
|
||||
};
|
||||
let context_after: Vec<&str> = content[after_line_end..].lines().take(4).collect();
|
||||
|
||||
// The line containing the cursor
|
||||
let cursor_line = &content[line_start..line_end];
|
||||
let cursor_column = cursor_byte_offset - line_start;
|
||||
|
||||
// Build the excerpt
|
||||
let mut excerpt = String::new();
|
||||
for line in context_before {
|
||||
excerpt.push_str(line);
|
||||
excerpt.push('\n');
|
||||
}
|
||||
// Track where cursor will be in the excerpt
|
||||
let cursor_offset_in_excerpt = excerpt.len() + cursor_column;
|
||||
// Line containing cursor
|
||||
excerpt.push_str(cursor_line);
|
||||
excerpt.push('\n');
|
||||
for line in context_after {
|
||||
excerpt.push_str(line);
|
||||
excerpt.push('\n');
|
||||
}
|
||||
|
||||
// Trim trailing newline
|
||||
if excerpt.ends_with('\n') {
|
||||
excerpt.pop();
|
||||
}
|
||||
|
||||
Ok((excerpt, cursor_offset_in_excerpt))
|
||||
}
|
||||
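A hypothetical test sketch (not in this diff) of `extract_cursor_excerpt`'s contract: up to three lines of context before the cursor line, four after, and the returned offset is relative to the excerpt rather than to the whole file.

```rust
#[cfg(test)]
mod extract_cursor_excerpt_sketch {
    use super::extract_cursor_excerpt;

    #[test]
    fn excerpt_is_centered_on_the_cursor_line() {
        let content = "a\nb\nc\nd\ne\nf\ng\nh\ni";
        let cursor = content.find('e').unwrap();
        let (excerpt, offset) = extract_cursor_excerpt(content, cursor).unwrap();
        // Three lines before "e", four after, trailing newline trimmed.
        assert_eq!(excerpt, "b\nc\nd\ne\nf\ng\nh\ni");
        // The offset points at the cursor position inside the excerpt.
        assert_eq!(&excerpt[offset..offset + 1], "e");
    }
}
```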
|
||||
/// Get the line comment prefix for a file based on its extension
|
||||
fn line_comment_prefix(file_path: &str) -> &'static str {
|
||||
let extension = file_path.rsplit('.').next().unwrap_or("");
|
||||
match extension {
|
||||
"rs" | "c" | "cpp" | "cc" | "h" | "hpp" | "js" | "ts" | "tsx" | "jsx" | "go" | "java"
|
||||
| "swift" | "kt" | "kts" | "scala" | "cs" | "m" | "mm" | "zig" | "v" | "d" => "//",
|
||||
"py" | "rb" | "sh" | "bash" | "zsh" | "pl" | "pm" | "r" | "jl" | "yaml" | "yml"
|
||||
| "toml" | "coffee" | "cr" | "ex" | "exs" | "elixir" => "#",
|
||||
"lua" | "hs" | "sql" => "--",
|
||||
"lisp" | "clj" | "cljs" | "scm" | "rkt" | "el" => ";",
|
||||
"erl" | "hrl" => "%",
|
||||
_ => "//",
|
||||
}
|
||||
}
|
||||
|
||||
fn format_rejected_example(response: &ClaudeResponse, rejection_reason: &str) -> String {
|
||||
let mut content = String::new();
|
||||
content.push_str("# Rejected Example\n\n");
|
||||
content.push_str(&format!("## Name\n\n{}\n\n", response.name));
|
||||
content.push_str(&format!("## Reasoning\n\n{}\n\n", response.reasoning));
|
||||
content.push_str("## Edit History Hunks\n\n```diff\n");
|
||||
for hunk in &response.edit_history_hunks {
|
||||
content.push_str(hunk);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
content.push_str("```\n\n");
|
||||
content.push_str("## Expected Patch Hunks\n\n```diff\n");
|
||||
for hunk in &response.expected_patch_hunks {
|
||||
content.push_str(hunk);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
content.push_str("```\n\n");
|
||||
content.push_str(&format!("## Rejection Reason\n\n{}\n", rejection_reason));
|
||||
content
|
||||
}
|
||||
@@ -15,8 +15,7 @@ doctest = false
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
git.workspace = true
|
||||
log.workspace = true
|
||||
collections.workspace = true
|
||||
time.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
@@ -50,11 +49,18 @@ zed_actions.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock.workspace = true
|
||||
copilot = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
futures.workspace = true
|
||||
indoc.workspace = true
|
||||
language_model.workspace = true
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
release_channel.workspace = true
|
||||
semver.workspace = true
|
||||
serde_json.workspace = true
|
||||
theme = { workspace = true, features = ["test-support"] }
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
@@ -915,11 +915,8 @@ impl EditPredictionButton {
|
||||
.when(
|
||||
cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
|
||||
|this| {
|
||||
this.action(
|
||||
"Capture Edit Prediction Example",
|
||||
CaptureExample.boxed_clone(),
|
||||
)
|
||||
.action("Rate Predictions", RatePredictions.boxed_clone())
|
||||
this.action("Capture Prediction Example", CaptureExample.boxed_clone())
|
||||
.action("Rate Predictions", RatePredictions.boxed_clone())
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,25 +2,17 @@ mod edit_prediction_button;
|
||||
mod edit_prediction_context_view;
|
||||
mod rate_prediction_modal;
|
||||
|
||||
use std::any::{Any as _, TypeId};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use edit_prediction::{
|
||||
EditPredictionStore, ResetOnboarding, Zeta2FeatureFlag, example_spec::ExampleSpec,
|
||||
};
|
||||
use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag, capture_example};
|
||||
use edit_prediction_context_view::EditPredictionContextView;
|
||||
use editor::Editor;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use git::repository::DiffType;
|
||||
use gpui::{Window, actions};
|
||||
use language::ToPoint as _;
|
||||
use log;
|
||||
use gpui::actions;
|
||||
use language::language_settings::AllLanguageSettings;
|
||||
use project::DisableAiSettings;
|
||||
use rate_prediction_modal::RatePredictionsModal;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use text::ToOffset as _;
|
||||
use std::any::{Any as _, TypeId};
|
||||
use ui::{App, prelude::*};
|
||||
use workspace::{SplitDirection, Workspace};
|
||||
|
||||
@@ -56,7 +48,9 @@ pub fn init(cx: &mut App) {
|
||||
}
|
||||
});
|
||||
|
||||
workspace.register_action(capture_edit_prediction_example);
|
||||
workspace.register_action(|workspace, _: &CaptureExample, window, cx| {
|
||||
capture_example_as_markdown(workspace, window, cx);
|
||||
});
|
||||
workspace.register_action_renderer(|div, _, _, cx| {
|
||||
let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
|
||||
div.when(has_flag, |div| {
|
||||
@@ -138,182 +132,48 @@ fn feature_gate_predict_edits_actions(cx: &mut App) {
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn capture_edit_prediction_example(
|
||||
fn capture_example_as_markdown(
|
||||
workspace: &mut Workspace,
|
||||
_: &CaptureExample,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let Some(ep_store) = EditPredictionStore::try_global(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let project = workspace.project().clone();
|
||||
|
||||
let (worktree_root, repository) = {
|
||||
let project_ref = project.read(cx);
|
||||
let worktree_root = project_ref
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| worktree.read(cx).abs_path());
|
||||
let repository = project_ref.active_repository(cx);
|
||||
(worktree_root, repository)
|
||||
};
|
||||
|
||||
let (Some(worktree_root), Some(repository)) = (worktree_root, repository) else {
|
||||
log::error!("CaptureExampleSpec: missing worktree or active repository");
|
||||
return;
|
||||
};
|
||||
|
||||
let repository_snapshot = repository.read(cx).snapshot();
|
||||
if worktree_root.as_ref() != repository_snapshot.work_directory_abs_path.as_ref() {
|
||||
log::error!(
|
||||
"repository is not at worktree root (repo={:?}, worktree={:?})",
|
||||
repository_snapshot.work_directory_abs_path,
|
||||
worktree_root
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(repository_url) = repository_snapshot
|
||||
.remote_origin_url
|
||||
.clone()
|
||||
.or_else(|| repository_snapshot.remote_upstream_url.clone())
|
||||
else {
|
||||
log::error!("active repository has no origin/upstream remote url");
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(revision) = repository_snapshot
|
||||
.head_commit
|
||||
.as_ref()
|
||||
.map(|commit| commit.sha.to_string())
|
||||
else {
|
||||
log::error!("active repository has no head commit");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut events = ep_store.update(cx, |store, cx| {
|
||||
store.edit_history_for_project_with_pause_split_last_event(&project, cx)
|
||||
});
|
||||
|
||||
let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
|
||||
log::error!("no active editor");
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(project_path) = editor.read(cx).project_path(cx) else {
|
||||
log::error!("active editor has no project path");
|
||||
return;
|
||||
};
|
||||
|
||||
let Some((buffer, cursor_anchor)) = editor
|
||||
.read(cx)
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.text_anchor_for_position(editor.read(cx).selections.newest_anchor().head(), cx)
|
||||
else {
|
||||
log::error!("failed to resolve cursor buffer/anchor");
|
||||
return;
|
||||
};
|
||||
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let cursor_point = cursor_anchor.to_point(&snapshot);
|
||||
let (_editable_range, context_range) =
|
||||
edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
|
||||
cursor_point,
|
||||
&snapshot,
|
||||
100,
|
||||
50,
|
||||
);
|
||||
|
||||
let cursor_path: Arc<Path> = repository
|
||||
.read(cx)
|
||||
.project_path_to_repo_path(&project_path, cx)
|
||||
.map(|repo_path| Path::new(repo_path.as_unix_str()).into())
|
||||
.unwrap_or_else(|| Path::new(project_path.path.as_unix_str()).into());
|
||||
|
||||
let cursor_position = {
|
||||
let context_start_offset = context_range.start.to_offset(&snapshot);
|
||||
let cursor_offset = cursor_anchor.to_offset(&snapshot);
|
||||
let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
|
||||
let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
|
||||
if cursor_offset_in_excerpt <= excerpt.len() {
|
||||
excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
|
||||
}
|
||||
excerpt
|
||||
};
|
||||
|
||||
) -> Option<()> {
|
||||
let markdown_language = workspace
|
||||
.app_state()
|
||||
.languages
|
||||
.language_for_name("Markdown");
|
||||
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
let project = workspace.project().clone();
|
||||
let editor = workspace.active_item_as::<Editor>(cx)?;
|
||||
let editor = editor.read(cx);
let (buffer, cursor_anchor) = editor
.buffer()
.read(cx)
.text_anchor_for_position(editor.selections.newest_anchor().head(), cx)?;
let example = capture_example(project.clone(), buffer, cursor_anchor, cx)?;

let examples_dir = AllLanguageSettings::get_global(cx)
.edit_predictions
.examples_dir
.clone();

cx.spawn_in(window, async move |workspace_entity, cx| {
let markdown_language = markdown_language.await?;
let example_spec = example.await?;
let buffer = if let Some(dir) = examples_dir {
fs.create_dir(&dir).await.ok();
let mut path = dir.join(&example_spec.name.replace(' ', "--").replace(':', "-"));
path.set_extension("md");
project.update(cx, |project, cx| project.open_local_buffer(&path, cx))
} else {
project.update(cx, |project, cx| project.create_buffer(false, cx))
}?
.await?;

let uncommitted_diff_rx = repository.update(cx, |repository, cx| {
repository.diff(DiffType::HeadToWorktree, cx)
})?;

let uncommitted_diff = match uncommitted_diff_rx.await {
Ok(Ok(diff)) => diff,
Ok(Err(error)) => {
log::error!("failed to compute uncommitted diff: {error:#}");
return Ok(());
}
Err(error) => {
log::error!("uncommitted diff channel dropped: {error:#}");
return Ok(());
}
};

let mut edit_history = String::new();
let mut expected_patch = String::new();
if let Some(last_event) = events.pop() {
for event in &events {
zeta_prompt::write_event(&mut edit_history, event);
if !edit_history.ends_with('\n') {
edit_history.push('\n');
}
edit_history.push('\n');
}

zeta_prompt::write_event(&mut expected_patch, &last_event);
}

let format =
time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
let name = match format {
Ok(format) => {
let now = time::OffsetDateTime::now_local()
.unwrap_or_else(|_| time::OffsetDateTime::now_utc());
now.format(&format)
.unwrap_or_else(|_| "unknown-time".to_string())
}
Err(_) => "unknown-time".to_string(),
};

let markdown = ExampleSpec {
name,
repository_url,
revision,
uncommitted_diff,
cursor_path,
cursor_position,
edit_history,
expected_patch,
}
.to_markdown();

let buffer = project
.update(cx, |project, cx| project.create_buffer(false, cx))?
.await?;
buffer.update(cx, |buffer, cx| {
buffer.set_text(markdown, cx);
buffer.set_text(example_spec.to_markdown(), cx);
buffer.set_language(Some(markdown_language), cx);
})?;

workspace_entity.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(
Box::new(
@@ -327,4 +187,5 @@ fn capture_edit_prediction_example(
})
})
.detach_and_log_err(cx);
None
}

@@ -25298,36 +25298,34 @@ impl EditorSnapshot {
/// Returns the line delta from `base` to `line` in the multibuffer, ignoring wrapped lines.
///
/// This is positive if `base` is before `line`.
fn relative_line_delta(&self, base: DisplayRow, line: DisplayRow) -> i64 {
fn relative_line_delta(
&self,
base: DisplayRow,
line: DisplayRow,
consider_wrapped_lines: bool,
) -> i64 {
let point = DisplayPoint::new(line, 0).to_point(self);
self.relative_line_delta_to_point(base, point)
self.relative_line_delta_to_point(base, point, consider_wrapped_lines)
}

/// Returns the line delta from `base` to `point` in the multibuffer, ignoring wrapped lines.
/// Returns the line delta from `base` to `point` in the multibuffer.
///
/// This is positive if `base` is before `point`.
pub fn relative_line_delta_to_point(&self, base: DisplayRow, point: Point) -> i64 {
pub fn relative_line_delta_to_point(
&self,
base: DisplayRow,
point: Point,
consider_wrapped_lines: bool,
) -> i64 {
let base_point = DisplayPoint::new(base, 0).to_point(self);
point.row as i64 - base_point.row as i64
}

/// Returns the line delta from `base` to `line` in the multibuffer, counting wrapped lines.
///
/// This is positive if `base` is before `line`.
fn relative_wrapped_line_delta(&self, base: DisplayRow, line: DisplayRow) -> i64 {
let point = DisplayPoint::new(line, 0).to_point(self);
self.relative_wrapped_line_delta_to_point(base, point)
}

/// Returns the line delta from `base` to `point` in the multibuffer, counting wrapped lines.
///
/// This is positive if `base` is before `point`.
pub fn relative_wrapped_line_delta_to_point(&self, base: DisplayRow, point: Point) -> i64 {
let base_point = DisplayPoint::new(base, 0).to_point(self);
let wrap_snapshot = self.wrap_snapshot();
let base_wrap_row = wrap_snapshot.make_wrap_point(base_point, Bias::Left).row();
let wrap_row = wrap_snapshot.make_wrap_point(point, Bias::Left).row();
wrap_row.0 as i64 - base_wrap_row.0 as i64
if consider_wrapped_lines {
let wrap_snapshot = self.wrap_snapshot();
let base_wrap_row = wrap_snapshot.make_wrap_point(base_point, Bias::Left).row();
let wrap_row = wrap_snapshot.make_wrap_point(point, Bias::Left).row();
wrap_row.0 as i64 - base_wrap_row.0 as i64
} else {
point.row as i64 - base_point.row as i64
}
}

/// Returns the unsigned relative line number to display for each row in `rows`.
@@ -25339,23 +25337,21 @@ impl EditorSnapshot {
relative_to: DisplayRow,
count_wrapped_lines: bool,
) -> HashMap<DisplayRow, u32> {
let initial_offset = if count_wrapped_lines {
self.relative_wrapped_line_delta(relative_to, rows.start)
} else {
self.relative_line_delta(relative_to, rows.start)
};
let display_row_infos = self
.row_infos(rows.start)
let initial_offset = self.relative_line_delta(relative_to, rows.start, count_wrapped_lines);

self.row_infos(rows.start)
.take(rows.len())
.enumerate()
.map(|(i, row_info)| (DisplayRow(rows.start.0 + i as u32), row_info));
display_row_infos
.map(|(i, row_info)| (DisplayRow(rows.start.0 + i as u32), row_info))
.filter(|(_row, row_info)| {
row_info.buffer_row.is_some()
|| (count_wrapped_lines && row_info.wrapped_buffer_row.is_some())
})
.enumerate()
.map(|(i, (row, _row_info))| (row, (initial_offset + i as i64).unsigned_abs() as u32))
.flat_map(|(i, (row, _row_info))| {
(row != relative_to)
.then_some((row, (initial_offset + i as i64).unsigned_abs() as u32))
})
.collect()
}
}
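A minimal standalone sketch of the numbering rule the hunk above converges on (base row excluded, offset counted from the start of the visible range). It is illustrative only and does not use Zed's editor types; `relative_numbers` and its arguments are hypothetical stand-ins for the real `calculate_relative_line_numbers` inputs.

use std::collections::HashMap;

// Simplified model: `initial_offset` plays the role of
// relative_line_delta(relative_to, rows.start, ...), and the row equal to
// `relative_to` gets no entry, matching the `.flat_map` + `then_some` above.
fn relative_numbers(rows: std::ops::Range<u32>, relative_to: u32, initial_offset: i64) -> HashMap<u32, u32> {
    rows.enumerate()
        .filter_map(|(i, row)| {
            (row != relative_to)
                .then_some((row, (initial_offset + i as i64).unsigned_abs() as u32))
        })
        .collect()
}

fn main() {
    // Base row 3 in rows 1..6: rows 1 and 2 count down to the base, rows 4 and 5 count up.
    let numbers = relative_numbers(1..6, 3, -2);
    assert_eq!(numbers[&1], 2);
    assert_eq!(numbers[&2], 1);
    assert!(!numbers.contains_key(&3));
    assert_eq!(numbers[&4], 1);
    assert_eq!(numbers[&5], 2);
}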

@@ -18346,7 +18346,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon
);

update_test_project_settings(cx, |project_settings| {
project_settings.lsp.insert(
project_settings.lsp.0.insert(
"Some other server name".into(),
LspSettings {
binary: None,
@@ -18367,7 +18367,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon
);

update_test_project_settings(cx, |project_settings| {
project_settings.lsp.insert(
project_settings.lsp.0.insert(
language_server_name.into(),
LspSettings {
binary: None,
@@ -18388,7 +18388,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon
);

update_test_project_settings(cx, |project_settings| {
project_settings.lsp.insert(
project_settings.lsp.0.insert(
language_server_name.into(),
LspSettings {
binary: None,
@@ -18409,7 +18409,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon
);

update_test_project_settings(cx, |project_settings| {
project_settings.lsp.insert(
project_settings.lsp.0.insert(
language_server_name.into(),
LspSettings {
binary: None,
@@ -28725,7 +28725,7 @@ fn test_relative_line_numbers(cx: &mut TestAppContext) {
assert_eq!(
relative_number,
snapshot
.relative_line_delta(display_row, base_display_row)
.relative_line_delta(display_row, base_display_row, false)
.unsigned_abs() as u32,
);
}
@@ -28735,6 +28735,7 @@ fn test_relative_line_numbers(cx: &mut TestAppContext) {
.into_iter()
.enumerate()
.map(|(i, row)| (DisplayRow(row), i.abs_diff(wrapped_base_row) as u32))
.filter(|(row, _)| *row != base_display_row)
.collect_vec();
let actual_relative_numbers = snapshot
.calculate_relative_line_numbers(
@@ -28751,7 +28752,7 @@ fn test_relative_line_numbers(cx: &mut TestAppContext) {
assert_eq!(
relative_number,
snapshot
.relative_wrapped_line_delta(display_row, base_display_row)
.relative_line_delta(display_row, base_display_row, true)
.unsigned_abs() as u32,
);
}
@@ -29602,6 +29603,17 @@ async fn test_newline_task_list_continuation(cx: &mut TestAppContext) {
- [ ] ˇ
"});

// Case 2.1: Works with uppercase checked marker too
cx.set_state(indoc! {"
- [X] completed taskˇ
"});
cx.update_editor(|e, window, cx| e.newline(&Newline, window, cx));
cx.wait_for_autoindent_applied().await;
cx.assert_editor_state(indoc! {"
- [X] completed task
- [ ] ˇ
"});

// Case 3: Cursor position doesn't matter - content after marker is what counts
cx.set_state(indoc! {"
- [ ] taˇsk

@@ -4611,15 +4611,15 @@ impl EditorElement {
);

let line_number = show_line_numbers.then(|| {
let relative_number = relative_to.and_then(|base| match relative_line_numbers {
RelativeLineNumbers::Disabled => None,
RelativeLineNumbers::Enabled => {
Some(snapshot.relative_line_delta_to_point(base, start_point))
}
RelativeLineNumbers::Wrapped => {
Some(snapshot.relative_wrapped_line_delta_to_point(base, start_point))
}
});
let relative_number = relative_to
.filter(|_| relative_line_numbers != RelativeLineNumbers::Disabled)
.map(|base| {
snapshot.relative_line_delta_to_point(
base,
start_point,
relative_line_numbers == RelativeLineNumbers::Wrapped,
)
});
let number = relative_number
.filter(|&delta| delta != 0)
.map(|delta| delta.unsigned_abs() as u32)
@@ -9055,14 +9055,8 @@ impl Element for EditorElement {
let em_advance = window.text_system().em_advance(font_id, font_size).unwrap();
let glyph_grid_cell = size(em_advance, line_height);

let gutter_dimensions = snapshot
.gutter_dimensions(
font_id,
font_size,
style,
window,
cx,
);
let gutter_dimensions =
snapshot.gutter_dimensions(font_id, font_size, style, window, cx);
let text_width = bounds.size.width - gutter_dimensions.width;

let settings = EditorSettings::get_global(cx);
@@ -9276,10 +9270,10 @@ impl Element for EditorElement {
};

let background_color = match diff_status.kind {
DiffHunkStatusKind::Added =>
cx.theme().colors().version_control_added,
DiffHunkStatusKind::Deleted =>
cx.theme().colors().version_control_deleted,
DiffHunkStatusKind::Added => cx.theme().colors().version_control_added,
DiffHunkStatusKind::Deleted => {
cx.theme().colors().version_control_deleted
}
DiffHunkStatusKind::Modified => {
debug_panic!("modified diff status for row info");
continue;
@@ -9423,25 +9417,26 @@ impl Element for EditorElement {
);

// relative rows are based on newest selection, even outside the visible area
let relative_row_base = self.editor.update(cx, |editor, cx| {
if editor.selections.count()==0 {
return None;
}
let relative_row_base = self.editor.update(cx, |editor, cx| {
(editor.selections.count() != 0).then(|| {
let newest = editor
.selections
.newest::<Point>(&editor.display_snapshot(cx));
Some(SelectionLayout::new(

SelectionLayout::new(
newest,
editor.selections.line_mode(),
editor.cursor_offset_on_selection,
editor.cursor_shape,
&snapshot.display_snapshot,
&snapshot,
true,
true,
None,
)
.head.row())
});
.head
.row()
})
});

let mut breakpoint_rows = self.editor.update(cx, |editor, cx| {
editor.active_breakpoints(start_row..end_row, window, cx)
@@ -9601,9 +9596,10 @@ impl Element for EditorElement {
cx,
);
} else {
debug_panic!(
"skipping recursive prepaint at max depth. renderer widths may be stale."
);
debug_panic!(concat!(
"skipping recursive prepaint at max depth. ",
"renderer widths may be stale."
));
}
}

@@ -9715,9 +9711,10 @@ impl Element for EditorElement {
cx,
);
} else {
debug_panic!(
"skipping recursive prepaint at max depth. block layout may be stale."
);
debug_panic!(concat!(
"skipping recursive prepaint at max depth. ",
"block layout may be stale."
));
}
}

@@ -11723,6 +11720,7 @@ mod tests {
assert_eq!(relative_rows[&DisplayRow(1)], 2);
assert_eq!(relative_rows[&DisplayRow(2)], 1);
// current line has no relative number
assert!(!relative_rows.contains_key(&DisplayRow(3)));
assert_eq!(relative_rows[&DisplayRow(4)], 1);
assert_eq!(relative_rows[&DisplayRow(5)], 2);

@@ -11869,6 +11867,7 @@ mod tests {
assert_eq!(relative_rows[&DisplayRow(1)], 2);
assert_eq!(relative_rows[&DisplayRow(2)], 1);
// current line has no relative number
assert!(!relative_rows.contains_key(&DisplayRow(3)));
assert_eq!(relative_rows[&DisplayRow(4)], 1);
assert_eq!(relative_rows[&DisplayRow(5)], 2);

@@ -11924,6 +11923,7 @@ mod tests {
assert_eq!(relative_rows[&DisplayRow(1)], 2);
assert_eq!(relative_rows[&DisplayRow(2)], 1);
// current line, even if deleted, has no relative number
assert!(!relative_rows.contains_key(&DisplayRow(3)));
assert_eq!(relative_rows[&DisplayRow(4)], 1);
assert_eq!(relative_rows[&DisplayRow(5)], 2);
}

@@ -24,7 +24,7 @@ use std::{borrow::Cow, cell::RefCell};
use std::{ops::Range, sync::Arc, time::Duration};
use std::{path::PathBuf, rc::Rc};
use theme::ThemeSettings;
use ui::{Scrollbars, WithScrollbar, prelude::*, theme_is_transparent};
use ui::{CopyButton, Scrollbars, WithScrollbar, prelude::*, theme_is_transparent};
use url::Url;
use util::TryFutureExt;
use workspace::{OpenOptions, OpenVisible, Workspace};
@@ -994,11 +994,13 @@ impl DiagnosticPopover {
.border_color(self.border_color)
.rounded_lg()
.child(
div()
h_flex()
.id("diagnostic-content-container")
.overflow_y_scroll()
.gap_1()
.items_start()
.max_w(max_size.width)
.max_h(max_size.height)
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
.child(
MarkdownElement::new(
@@ -1021,7 +1023,11 @@ impl DiagnosticPopover {
}
},
),
),
)
.child({
let message = self.local_diagnostic.diagnostic.message.clone();
CopyButton::new(message).tooltip_label("Copy Diagnostic")
}),
)
.custom_scrollbars(
Scrollbars::for_settings::<EditorSettings>()

@@ -164,11 +164,6 @@ pub fn deploy_context_menu(
window.focus(&editor.focus_handle(cx), cx);
}

// Don't show context menu for inline editors
if !editor.mode().is_full() {
return;
}

let display_map = editor.display_snapshot(cx);
let source_anchor = display_map.display_point_to_anchor(point, text::Bias::Right);
let context_menu = if let Some(custom) = editor.custom_context_menu.take() {
@@ -179,6 +174,11 @@ pub fn deploy_context_menu(
};
menu
} else {
// Don't show context menu for inline editors (only applies to default menu)
if !editor.mode().is_full() {
return;
}

// Don't show the context menu if there isn't a project associated with this editor
let Some(project) = editor.project.clone() else {
return;

@@ -23,3 +23,9 @@ pub struct AgentV2FeatureFlag;
impl FeatureFlag for AgentV2FeatureFlag {
const NAME: &'static str = "agent-v2";
}

pub struct AcpBetaFeatureFlag;

impl FeatureFlag for AcpBetaFeatureFlag {
const NAME: &'static str = "acp-beta";
}

@@ -1760,16 +1760,19 @@ impl PickerDelegate for FileFinderDelegate {
menu.context(focus_handle)
.action(
"Split Left",
pane::SplitLeft.boxed_clone(),
pane::SplitLeft::default().boxed_clone(),
)
.action(
"Split Right",
pane::SplitRight.boxed_clone(),
pane::SplitRight::default().boxed_clone(),
)
.action(
"Split Up",
pane::SplitUp::default().boxed_clone(),
)
.action("Split Up", pane::SplitUp.boxed_clone())
.action(
"Split Down",
pane::SplitDown.boxed_clone(),
pane::SplitDown::default().boxed_clone(),
)
}
}))

@@ -156,8 +156,16 @@ impl GitRepository for FakeGitRepository {
})
}

fn remote_url(&self, _name: &str) -> BoxFuture<'_, Option<String>> {
async move { None }.boxed()
fn remote_url(&self, name: &str) -> BoxFuture<'_, Option<String>> {
let name = name.to_string();
let fut = self.with_state_async(false, move |state| {
state
.remotes
.get(&name)
.context("remote not found")
.cloned()
});
async move { fut.await.ok() }.boxed()
}

fn diff_tree(&self, _request: DiffTreeType) -> BoxFuture<'_, Result<TreeDiff>> {

@@ -335,12 +335,11 @@ impl FileHandle for std::fs::File {
let mut path_buf = MaybeUninit::<[u8; libc::PATH_MAX as usize]>::uninit();

let result = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_GETPATH, path_buf.as_mut_ptr()) };
if result == -1 {
anyhow::bail!("fcntl returned -1".to_string());
}
anyhow::ensure!(result != -1, "fcntl returned -1");

// SAFETY: `fcntl` will initialize the path buffer.
let c_str = unsafe { CStr::from_ptr(path_buf.as_ptr().cast()) };
anyhow::ensure!(!c_str.is_empty(), "Could find a path for the file handle");
let path = PathBuf::from(OsStr::from_bytes(c_str.to_bytes()));
Ok(path)
}
@@ -372,12 +371,11 @@ impl FileHandle for std::fs::File {
kif.kf_structsize = libc::KINFO_FILE_SIZE;

let result = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_KINFO, kif.as_mut_ptr()) };
if result == -1 {
anyhow::bail!("fcntl returned -1".to_string());
}
anyhow::ensure!(result != -1, "fcntl returned -1");

// SAFETY: `fcntl` will initialize the kif.
let c_str = unsafe { CStr::from_ptr(kif.assume_init().kf_path.as_ptr()) };
anyhow::ensure!(!c_str.is_empty(), "Could find a path for the file handle");
let path = PathBuf::from(OsStr::from_bytes(c_str.to_bytes()));
Ok(path)
}
@@ -398,18 +396,21 @@ impl FileHandle for std::fs::File {
// Query required buffer size (in wide chars)
let required_len =
unsafe { GetFinalPathNameByHandleW(handle, &mut [], FILE_NAME_NORMALIZED) };
if required_len == 0 {
anyhow::bail!("GetFinalPathNameByHandleW returned 0 length");
}
anyhow::ensure!(
required_len != 0,
"GetFinalPathNameByHandleW returned 0 length"
);

// Allocate buffer and retrieve the path
let mut buf: Vec<u16> = vec![0u16; required_len as usize + 1];
let written = unsafe { GetFinalPathNameByHandleW(handle, &mut buf, FILE_NAME_NORMALIZED) };
if written == 0 {
anyhow::bail!("GetFinalPathNameByHandleW failed to write path");
}
anyhow::ensure!(
written != 0,
"GetFinalPathNameByHandleW failed to write path"
);

let os_str: OsString = OsString::from_wide(&buf[..written as usize]);
anyhow::ensure!(!os_str.is_empty(), "Could find a path for the file handle");
Ok(PathBuf::from(os_str))
}
}
@@ -1857,6 +1858,18 @@ impl FakeFs {
.unwrap();
}

pub fn set_remote_for_repo(
&self,
dot_git: &Path,
name: impl Into<String>,
url: impl Into<String>,
) {
self.with_git_state(dot_git, true, |state| {
state.remotes.insert(name.into(), url.into());
})
.unwrap();
}

pub fn insert_branches(&self, dot_git: &Path, branches: &[&str]) {
self.with_git_state(dot_git, true, |state| {
if let Some(first) = branches.first()

@@ -76,7 +76,7 @@ impl EventStream {
cf::CFRelease(cf_path);
cf::CFRelease(cf_url);
} else {
log::error!("Failed to create CFURL for path: {}", path.display());
log::error!("Failed to create CFURL for path: {path:?}");
}
}


@@ -13,7 +13,7 @@ use project::{git_store::Repository, project_settings::ProjectSettings};
use settings::Settings as _;
use theme::ThemeSettings;
use time::OffsetDateTime;
use ui::{ContextMenu, Divider, prelude::*, tooltip_container};
use ui::{ContextMenu, CopyButton, Divider, prelude::*, tooltip_container};
use workspace::Workspace;

const GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED: usize = 20;
@@ -335,18 +335,10 @@ impl BlameRenderer for GitBlameRenderer {
cx.stop_propagation();
}),
)
.child(Divider::vertical())
.child(
IconButton::new("copy-sha-button", IconName::Copy)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| {
cx.stop_propagation();
cx.write_to_clipboard(
ClipboardItem::new_string(
sha.to_string(),
),
)
}),
CopyButton::new(sha.to_string())
.tooltip_label("Copy SHA"),
),
),
),

@@ -5,7 +5,7 @@ use git::blame::BlameEntry;
use git::repository::CommitSummary;
use git::{GitRemote, commit::ParsedCommitMessage};
use gpui::{
App, Asset, ClipboardItem, Element, Entity, MouseButton, ParentElement, Render, ScrollHandle,
App, Asset, Element, Entity, MouseButton, ParentElement, Render, ScrollHandle,
StatefulInteractiveElement, WeakEntity, prelude::*,
};
use markdown::{Markdown, MarkdownElement};
@@ -14,7 +14,7 @@ use settings::Settings;
use std::hash::Hash;
use theme::ThemeSettings;
use time::{OffsetDateTime, UtcOffset};
use ui::{Avatar, Divider, IconButtonShape, prelude::*, tooltip_container};
use ui::{Avatar, CopyButton, Divider, prelude::*, tooltip_container};
use workspace::Workspace;

#[derive(Clone, Debug)]
@@ -315,8 +315,8 @@ impl Render for CommitTooltip {
cx.open_url(pr.url.as_str())
}),
)
.child(Divider::vertical())
})
.child(Divider::vertical())
.child(
Button::new(
"commit-sha-button",
@@ -342,18 +342,8 @@ impl Render for CommitTooltip {
},
),
)
.child(
IconButton::new("copy-sha-button", IconName::Copy)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| {
cx.stop_propagation();
cx.write_to_clipboard(
ClipboardItem::new_string(full_sha.clone()),
)
}),
),
.child(Divider::vertical())
.child(CopyButton::new(full_sha).tooltip_label("Copy SHA")),
),
),
)

@@ -8,9 +8,9 @@ use git::{
parse_git_remote_url,
};
use gpui::{
AnyElement, App, AppContext as _, AsyncApp, AsyncWindowContext, Context, Element, Entity,
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ParentElement,
PromptLevel, Render, Styled, Task, WeakEntity, Window, actions,
AnyElement, App, AppContext as _, AsyncApp, AsyncWindowContext, ClipboardItem, Context,
Element, Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement,
ParentElement, PromptLevel, Render, Styled, Task, WeakEntity, Window, actions,
};
use language::{
Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, OffsetRangeExt as _,
@@ -24,7 +24,7 @@ use std::{
sync::Arc,
};
use theme::ActiveTheme;
use ui::{DiffStat, Tooltip, prelude::*};
use ui::{ButtonLike, DiffStat, Tooltip, prelude::*};
use util::{ResultExt, paths::PathStyle, rel_path::RelPath, truncate_and_trailoff};
use workspace::item::TabTooltipContent;
use workspace::{
@@ -383,6 +383,7 @@ impl CommitView {
fn render_header(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let commit = &self.commit;
let author_name = commit.author_name.clone();
let commit_sha = commit.sha.clone();
let commit_date = time::OffsetDateTime::from_unix_timestamp(commit.commit_timestamp)
.unwrap_or_else(|_| time::OffsetDateTime::now_utc());
let local_offset = time::UtcOffset::current_local_offset().unwrap_or(time::UtcOffset::UTC);
@@ -429,6 +430,19 @@ impl CommitView {
.full_width()
});

let clipboard_has_link = cx
.read_from_clipboard()
.and_then(|entry| entry.text())
.map_or(false, |clipboard_text| {
clipboard_text.trim() == commit_sha.as_ref()
});

let (copy_icon, copy_icon_color) = if clipboard_has_link {
(IconName::Check, Color::Success)
} else {
(IconName::Copy, Color::Muted)
};

h_flex()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
@@ -454,13 +468,47 @@ impl CommitView {
h_flex()
.gap_1()
.child(Label::new(author_name).color(Color::Default))
.child(
Label::new(format!("Commit:{}", commit.sha))
.color(Color::Muted)
.size(LabelSize::Small)
.truncate()
.buffer_font(cx),
),
.child({
ButtonLike::new("sha")
.child(
h_flex()
.group("sha_btn")
.size_full()
.max_w_32()
.gap_0p5()
.child(
Label::new(commit_sha.clone())
.color(Color::Muted)
.size(LabelSize::Small)
.truncate()
.buffer_font(cx),
)
.child(
div().visible_on_hover("sha_btn").child(
Icon::new(copy_icon)
.color(copy_icon_color)
.size(IconSize::Small),
),
),
)
.tooltip({
let commit_sha = commit_sha.clone();
move |_, cx| {
Tooltip::with_meta(
"Copy Commit SHA",
None,
commit_sha.clone(),
cx,
)
}
})
.on_click(move |_, _, cx| {
cx.stop_propagation();
cx.write_to_clipboard(ClipboardItem::new_string(
commit_sha.to_string(),
));
})
}),
)
.child(
h_flex()

@@ -3638,7 +3638,7 @@ impl GitPanel {
self.entry_count += 1;
let is_staging_or_staged = GitPanel::stage_status_for_entry(status_entry, repo)
.as_bool()
.unwrap_or(false);
.unwrap_or(true);

if repo.had_conflict_on_last_merge_head_change(&status_entry.repo_path) {
self.conflicted_count += 1;

@@ -2154,7 +2154,6 @@ impl Interactivity {
|| cx.active_drag.is_some() && !self.drag_over_styles.is_empty()
{
let hitbox = hitbox.clone();
let was_hovered = hitbox.is_hovered(window);
let hover_state = self.hover_style.as_ref().and_then(|_| {
element_state
.as_ref()
@@ -2162,8 +2161,12 @@ impl Interactivity {
.cloned()
});
let current_view = window.current_view();

window.on_mouse_event(move |_: &MouseMoveEvent, phase, window, cx| {
let hovered = hitbox.is_hovered(window);
let was_hovered = hover_state
.as_ref()
.is_some_and(|state| state.borrow().element);
if phase == DispatchPhase::Capture && hovered != was_hovered {
if let Some(hover_state) = &hover_state {
hover_state.borrow_mut().element = hovered;
@@ -2179,12 +2182,13 @@ impl Interactivity {
.as_ref()
.and_then(|element| element.hover_state.as_ref())
.cloned();

let was_group_hovered = group_hitbox_id.is_hovered(window);
let current_view = window.current_view();

window.on_mouse_event(move |_: &MouseMoveEvent, phase, window, cx| {
let group_hovered = group_hitbox_id.is_hovered(window);
let was_group_hovered = hover_state
.as_ref()
.is_some_and(|state| state.borrow().group);
if phase == DispatchPhase::Capture && group_hovered != was_group_hovered {
if let Some(hover_state) = &hover_state {
hover_state.borrow_mut().group = group_hovered;

@@ -46,9 +46,9 @@ pub unsafe fn new_renderer(
_native_window: *mut c_void,
_native_view: *mut c_void,
_bounds: crate::Size<f32>,
_transparent: bool,
transparent: bool,
) -> Renderer {
MetalRenderer::new(context)
MetalRenderer::new(context, transparent)
}

pub(crate) struct InstanceBufferPool {
@@ -128,7 +128,7 @@ pub struct PathRasterizationVertex {
}

impl MetalRenderer {
pub fn new(instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>) -> Self {
pub fn new(instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>, transparent: bool) -> Self {
// Prefer low‐power integrated GPUs on Intel Mac. On Apple
// Silicon, there is only ever one GPU, so this is equivalent to
// `metal::Device::system_default()`.
@@ -152,7 +152,9 @@ impl MetalRenderer {
let layer = metal::MetalLayer::new();
layer.set_device(&device);
layer.set_pixel_format(MTLPixelFormat::BGRA8Unorm);
layer.set_opaque(false);
// Support direct-to-display rendering if the window is not transparent
// https://developer.apple.com/documentation/metal/managing-your-game-window-for-metal-in-macos
layer.set_opaque(!transparent);
layer.set_maximum_drawable_count(3);
unsafe {
let _: () = msg_send![&*layer, setAllowsNextDrawableTimeout: NO];
@@ -352,8 +354,8 @@ impl MetalRenderer {
}
}

pub fn update_transparency(&self, _transparent: bool) {
// todo(mac)?
pub fn update_transparency(&self, transparent: bool) {
self.layer.set_opaque(!transparent);
}

pub fn destroy(&self) {

@@ -42,7 +42,7 @@ impl WindowsWindowInner {
let handled = match msg {
// eagerly activate the window, so calls to `active_window` will work correctly
WM_MOUSEACTIVATE => {
unsafe { SetActiveWindow(handle).log_err() };
unsafe { SetActiveWindow(handle).ok() };
None
}
WM_ACTIVATE => self.handle_activate_msg(wparam),

@@ -740,8 +740,8 @@ impl PlatformWindow for WindowsWindow {
ShowWindowAsync(hwnd, SW_RESTORE).ok().log_err();
}

SetActiveWindow(hwnd).log_err();
SetFocus(Some(hwnd)).log_err();
SetActiveWindow(hwnd).ok();
SetFocus(Some(hwnd)).ok();
}

// premium ragebait by windows, this is needed because the window

@@ -20,6 +20,7 @@ dap.workspace = true
extension.workspace = true
gpui.workspace = true
language.workspace = true
lsp.workspace = true
paths.workspace = true
project.workspace = true
schemars.workspace = true

@@ -2,9 +2,11 @@
use std::{str::FromStr, sync::Arc};

use anyhow::{Context as _, Result};
use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, WeakEntity};
use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, Task, WeakEntity};
use language::{LanguageRegistry, language_settings::all_language_settings};
use project::LspStore;
use lsp::LanguageServerBinaryOptions;
use project::{LspStore, lsp_store::LocalLspAdapterDelegate};
use settings::LSP_SETTINGS_SCHEMA_URL_PREFIX;
use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields};

// Origin: https://github.com/SchemaStore/schemastore
@@ -75,23 +77,28 @@ fn handle_schema_request(
lsp_store: Entity<LspStore>,
uri: String,
cx: &mut AsyncApp,
) -> Result<String> {
let languages = lsp_store.read_with(cx, |lsp_store, _| lsp_store.languages.clone())?;
let schema = resolve_schema_request(&languages, uri, cx)?;
serde_json::to_string(&schema).context("Failed to serialize schema")
) -> Task<Result<String>> {
let languages = lsp_store.read_with(cx, |lsp_store, _| lsp_store.languages.clone());
cx.spawn(async move |cx| {
let languages = languages?;
let schema = resolve_schema_request(&languages, lsp_store, uri, cx).await?;
serde_json::to_string(&schema).context("Failed to serialize schema")
})
}

pub fn resolve_schema_request(
pub async fn resolve_schema_request(
languages: &Arc<LanguageRegistry>,
lsp_store: Entity<LspStore>,
uri: String,
cx: &mut AsyncApp,
) -> Result<serde_json::Value> {
let path = uri.strip_prefix("zed://schemas/").context("Invalid URI")?;
resolve_schema_request_inner(languages, path, cx)
resolve_schema_request_inner(languages, lsp_store, path, cx).await
}

pub fn resolve_schema_request_inner(
pub async fn resolve_schema_request_inner(
languages: &Arc<LanguageRegistry>,
lsp_store: Entity<LspStore>,
path: &str,
cx: &mut AsyncApp,
) -> Result<serde_json::Value> {
@@ -99,37 +106,121 @@ pub fn resolve_schema_request_inner(
let schema_name = schema_name.unwrap_or(path);

let schema = match schema_name {
"settings" => cx.update(|cx| {
let font_names = &cx.text_system().all_font_names();
let language_names = &languages
.language_names()
"settings" if rest.is_some_and(|r| r.starts_with("lsp/")) => {
let lsp_name = rest
.and_then(|r| {
r.strip_prefix(
LSP_SETTINGS_SCHEMA_URL_PREFIX
.strip_prefix("zed://schemas/settings/")
.unwrap(),
)
})
.context("Invalid LSP schema path")?;

let adapter = languages
.all_lsp_adapters()
.into_iter()
.map(|name| name.to_string())
.find(|adapter| adapter.name().as_ref() as &str == lsp_name)
.with_context(|| format!("LSP adapter not found: {}", lsp_name))?;

let delegate = cx
.update(|inner_cx| {
lsp_store.update(inner_cx, |lsp_store, inner_cx| {
let Some(local) = lsp_store.as_local() else {
return None;
};
let Some(worktree) = local.worktree_store.read(inner_cx).worktrees().next()
else {
return None;
};
Some(LocalLspAdapterDelegate::from_local_lsp(
local, &worktree, inner_cx,
))
})
})?
.context(concat!(
"Failed to create adapter delegate - ",
"either LSP store is not in local mode or no worktree is available"
))?;

let adapter_for_schema = adapter.clone();

let binary = adapter
.get_language_server_command(
delegate,
None,
LanguageServerBinaryOptions {
allow_path_lookup: true,
allow_binary_download: false,
pre_release: false,
},
cx,
)
.await
.await
.0
.with_context(|| {
format!(
concat!(
"Failed to find language server {} ",
"to generate initialization params schema"
),
lsp_name
)
})?;

adapter_for_schema
.adapter
.clone()
.initialization_options_schema(&binary)
.await
.unwrap_or_else(|| {
serde_json::json!({
"type": "object",
"additionalProperties": true
})
})
}
"settings" => {
let lsp_adapter_names = languages
.all_lsp_adapters()
.into_iter()
.map(|adapter| adapter.name().to_string())
.collect::<Vec<_>>();

let mut icon_theme_names = vec![];
let mut theme_names = vec![];
if let Some(registry) = theme::ThemeRegistry::try_global(cx) {
icon_theme_names.extend(
registry
.list_icon_themes()
.into_iter()
.map(|icon_theme| icon_theme.name),
);
theme_names.extend(registry.list_names());
}
let icon_theme_names = icon_theme_names.as_slice();
let theme_names = theme_names.as_slice();
cx.update(|cx| {
let font_names = &cx.text_system().all_font_names();
let language_names = &languages
.language_names()
.into_iter()
.map(|name| name.to_string())
.collect::<Vec<_>>();

cx.global::<settings::SettingsStore>().json_schema(
&settings::SettingsJsonSchemaParams {
language_names,
font_names,
theme_names,
icon_theme_names,
},
)
})?,
let mut icon_theme_names = vec![];
let mut theme_names = vec![];
if let Some(registry) = theme::ThemeRegistry::try_global(cx) {
icon_theme_names.extend(
registry
.list_icon_themes()
.into_iter()
.map(|icon_theme| icon_theme.name),
);
theme_names.extend(registry.list_names());
}
let icon_theme_names = icon_theme_names.as_slice();
let theme_names = theme_names.as_slice();

cx.global::<settings::SettingsStore>().json_schema(
&settings::SettingsJsonSchemaParams {
language_names,
font_names,
theme_names,
icon_theme_names,
lsp_adapter_names: &lsp_adapter_names,
},
)
})?
}
"keymap" => cx.update(settings::KeymapFile::generate_json_schema_for_registered_actions)?,
"action" => {
let normalized_action_name = rest.context("No Action name provided")?;

@@ -13,7 +13,7 @@ use crate::{
},
task_context::RunnableRange,
text_diff::text_diff,
unified_diff,
unified_diff_with_offsets,
};
pub use crate::{
Grammar, Language, LanguageRegistry,
@@ -773,7 +773,11 @@ pub struct EditPreview {
}

impl EditPreview {
pub fn as_unified_diff(&self, edits: &[(Range<Anchor>, impl AsRef<str>)]) -> Option<String> {
pub fn as_unified_diff(
&self,
file: Option<&Arc<dyn File>>,
edits: &[(Range<Anchor>, impl AsRef<str>)],
) -> Option<String> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;

@@ -788,7 +792,7 @@ impl EditPreview {
let old_end = Point::new(old_end.row + 4, 0).min(self.old_snapshot.max_point());
let new_end = Point::new(new_end.row + 4, 0).min(self.applied_edits_snapshot.max_point());

Some(unified_diff(
let diff_body = unified_diff_with_offsets(
&self
.old_snapshot
.text_for_range(start..old_end)
@@ -797,7 +801,17 @@ impl EditPreview {
.applied_edits_snapshot
.text_for_range(start..new_end)
.collect::<String>(),
))
start.row,
start.row,
);

let path = file.map(|f| f.path().as_unix_str());
let header = match path {
Some(p) => format!("--- a/{}\n+++ b/{}\n", p, p),
None => String::new(),
};

Some(format!("{}{}", header, diff_body))
}

pub fn highlight_edits(

@@ -67,7 +67,7 @@ use task::RunnableTag;
pub use task_context::{ContextLocation, ContextProvider, RunnableRange};
pub use text_diff::{
DiffOptions, apply_diff_patch, line_diff, text_diff, text_diff_with_options, unified_diff,
word_diff_ranges,
unified_diff_with_offsets, word_diff_ranges,
};
use theme::SyntaxTheme;
pub use toolchain::{
@@ -461,6 +461,14 @@ pub trait LspAdapter: 'static + Send + Sync + DynLspInstaller {
Ok(None)
}

/// Returns the JSON schema of the initialization_options for the language server.
async fn initialization_options_schema(
self: Arc<Self>,
_language_server_binary: &LanguageServerBinary,
) -> Option<serde_json::Value> {
None
}

async fn workspace_configuration(
self: Arc<Self>,
_: &Arc<dyn LspAdapterDelegate>,

@@ -392,6 +392,7 @@ pub struct EditPredictionSettings {
/// Whether edit predictions are enabled in the assistant panel.
/// This setting has no effect if globally disabled.
pub enabled_in_text_threads: bool,
pub examples_dir: Option<Arc<Path>>,
}

impl EditPredictionSettings {
@@ -699,6 +700,7 @@ impl settings::Settings for AllLanguageSettings {
copilot: copilot_settings,
codestral: codestral_settings,
enabled_in_text_threads,
examples_dir: edit_predictions.examples_dir,
},
defaults: default_language_settings,
languages,

@@ -1,25 +1,139 @@
use crate::{CharClassifier, CharKind, CharScopeContext, LanguageScope};
use anyhow::{Context, anyhow};
use imara_diff::{
Algorithm, UnifiedDiffBuilder, diff,
intern::{InternedInput, Token},
Algorithm, Sink, diff,
intern::{InternedInput, Interner, Token},
sources::lines_with_terminator,
};
use std::{iter, ops::Range, sync::Arc};
use std::{fmt::Write, iter, ops::Range, sync::Arc};

const MAX_WORD_DIFF_LEN: usize = 512;
const MAX_WORD_DIFF_LINE_COUNT: usize = 8;

/// Computes a diff between two strings, returning a unified diff string.
pub fn unified_diff(old_text: &str, new_text: &str) -> String {
unified_diff_with_offsets(old_text, new_text, 0, 0)
}

/// Computes a diff between two strings, returning a unified diff string with
/// hunk headers adjusted to reflect the given starting line numbers (1-indexed).
pub fn unified_diff_with_offsets(
old_text: &str,
new_text: &str,
old_start_line: u32,
new_start_line: u32,
) -> String {
let input = InternedInput::new(old_text, new_text);
diff(
Algorithm::Histogram,
&input,
UnifiedDiffBuilder::new(&input),
OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line),
)
}

/// A unified diff builder that applies line number offsets to hunk headers.
struct OffsetUnifiedDiffBuilder<'a> {
before: &'a [Token],
after: &'a [Token],
interner: &'a Interner<&'a str>,

pos: u32,
before_hunk_start: u32,
after_hunk_start: u32,
before_hunk_len: u32,
after_hunk_len: u32,

old_line_offset: u32,
new_line_offset: u32,

buffer: String,
dst: String,
}

impl<'a> OffsetUnifiedDiffBuilder<'a> {
fn new(input: &'a InternedInput<&'a str>, old_line_offset: u32, new_line_offset: u32) -> Self {
Self {
before_hunk_start: 0,
after_hunk_start: 0,
before_hunk_len: 0,
after_hunk_len: 0,
old_line_offset,
new_line_offset,
buffer: String::with_capacity(8),
dst: String::new(),
interner: &input.interner,
before: &input.before,
after: &input.after,
pos: 0,
}
}

fn print_tokens(&mut self, tokens: &[Token], prefix: char) {
for &token in tokens {
writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap();
}
}

fn flush(&mut self) {
if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
return;
}

let end = (self.pos + 3).min(self.before.len() as u32);
self.update_pos(end, end);

writeln!(
&mut self.dst,
"@@ -{},{} +{},{} @@",
self.before_hunk_start + 1 + self.old_line_offset,
self.before_hunk_len,
self.after_hunk_start + 1 + self.new_line_offset,
self.after_hunk_len,
)
.unwrap();
write!(&mut self.dst, "{}", &self.buffer).unwrap();
self.buffer.clear();
self.before_hunk_len = 0;
self.after_hunk_len = 0;
}

fn update_pos(&mut self, print_to: u32, move_to: u32) {
self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' ');
let len = print_to - self.pos;
self.pos = move_to;
self.before_hunk_len += len;
self.after_hunk_len += len;
}
}

impl Sink for OffsetUnifiedDiffBuilder<'_> {
type Out = String;

fn process_change(&mut self, before: Range<u32>, after: Range<u32>) {
if before.start - self.pos > 6 {
self.flush();
}
if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
self.pos = before.start.saturating_sub(3);
self.before_hunk_start = self.pos;
self.after_hunk_start = after.start.saturating_sub(3);
}
self.update_pos(before.start, before.end);
self.before_hunk_len += before.end - before.start;
self.after_hunk_len += after.end - after.start;
self.print_tokens(
&self.before[before.start as usize..before.end as usize],
'-',
);
self.print_tokens(&self.after[after.start as usize..after.end as usize], '+');
}

fn finish(mut self) -> Self::Out {
self.flush();
self.dst
}
}

/// Computes a diff between two strings, returning a vector of old and new row
/// ranges.
pub fn line_diff(old_text: &str, new_text: &str) -> Vec<(Range<u32>, Range<u32>)> {
@@ -327,4 +441,30 @@ mod tests {
let patch = unified_diff(old_text, new_text);
assert_eq!(apply_diff_patch(old_text, &patch).unwrap(), new_text);
}

#[test]
fn test_unified_diff_with_offsets() {
let old_text = "foo\nbar\nbaz\n";
let new_text = "foo\nBAR\nbaz\n";

let expected_diff_body = " foo\n-bar\n+BAR\n baz\n";

let diff_no_offset = unified_diff(old_text, new_text);
assert_eq!(
diff_no_offset,
format!("@@ -1,3 +1,3 @@\n{}", expected_diff_body)
);

let diff_with_offset = unified_diff_with_offsets(old_text, new_text, 9, 11);
assert_eq!(
diff_with_offset,
format!("@@ -10,3 +12,3 @@\n{}", expected_diff_body)
);

let diff_with_offset = unified_diff_with_offsets(old_text, new_text, 99, 104);
assert_eq!(
diff_with_offset,
format!("@@ -100,3 +105,3 @@\n{}", expected_diff_body)
);
}
}
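A standalone sketch of the hunk-header arithmetic that the offset builder above applies: hunk starts are tracked 0-based internally and printed 1-based, so an excerpt that begins at file line 10 passes an offset of 9. This is illustrative only; `offset_hunk_header` is a hypothetical helper, not part of the change.

// Mirrors the formatting in OffsetUnifiedDiffBuilder::flush, nothing more.
fn offset_hunk_header(
    before_start: u32,
    before_len: u32,
    after_start: u32,
    after_len: u32,
    old_line_offset: u32,
    new_line_offset: u32,
) -> String {
    format!(
        "@@ -{},{} +{},{} @@",
        before_start + 1 + old_line_offset,
        before_len,
        after_start + 1 + new_line_offset,
        after_len,
    )
}

fn main() {
    // Matches the test above: offsets 9 and 11 shift "@@ -1,3 +1,3 @@" to "@@ -10,3 +12,3 @@".
    assert_eq!(offset_hunk_header(0, 3, 0, 3, 9, 11), "@@ -10,3 +12,3 @@");
}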

@@ -1,6 +1,6 @@
name = "JSONC"
grammar = "jsonc"
path_suffixes = ["jsonc", "bun.lock", "tsconfig.json", "pyrightconfig.json"]
path_suffixes = ["jsonc", "bun.lock", "devcontainer.json", "pyrightconfig.json", "tsconfig.json"]
line_comments = ["// "]
autoclose_before = ",]}"
brackets = [

@@ -22,7 +22,7 @@ rewrap_prefixes = [
]
unordered_list = ["- ", "* ", "+ "]
ordered_list = [{ pattern = "(\\d+)\\. ", format = "{1}. " }]
task_list = { prefixes = ["- [ ] ", "- [x] "], continuation = "- [ ] " }
task_list = { prefixes = ["- [ ] ", "- [x] ", "- [X] "], continuation = "- [ ] " }

auto_indent_on_paste = false
auto_indent_using_last_non_empty_line = false

@@ -26,6 +26,7 @@ use settings::Settings;
use smol::lock::OnceCell;
use std::cmp::{Ordering, Reverse};
use std::env::consts;
use std::process::Stdio;
use terminal::terminal_settings::TerminalSettings;
use util::command::new_smol_command;
use util::fs::{make_file_executable, remove_matching};
@@ -2173,6 +2174,119 @@ pub(crate) struct RuffLspAdapter {
fs: Arc<dyn Fs>,
}

impl RuffLspAdapter {
fn convert_ruff_schema(raw_schema: &serde_json::Value) -> serde_json::Value {
let Some(schema_object) = raw_schema.as_object() else {
return raw_schema.clone();
};

let mut root_properties = serde_json::Map::new();

for (key, value) in schema_object {
let parts: Vec<&str> = key.split('.').collect();

if parts.is_empty() {
continue;
}

let mut current = &mut root_properties;

for (i, part) in parts.iter().enumerate() {
let is_last = i == parts.len() - 1;

if is_last {
let mut schema_entry = serde_json::Map::new();

if let Some(doc) = value.get("doc").and_then(|d| d.as_str()) {
schema_entry.insert(
"markdownDescription".to_string(),
serde_json::Value::String(doc.to_string()),
);
}

if let Some(default_val) = value.get("default") {
schema_entry.insert("default".to_string(), default_val.clone());
}

if let Some(value_type) = value.get("value_type").and_then(|v| v.as_str()) {
if value_type.contains('|') {
let enum_values: Vec<serde_json::Value> = value_type
.split('|')
.map(|s| s.trim().trim_matches('"'))
.filter(|s| !s.is_empty())
.map(|s| serde_json::Value::String(s.to_string()))
.collect();

if !enum_values.is_empty() {
schema_entry
.insert("type".to_string(), serde_json::json!("string"));
schema_entry.insert(
"enum".to_string(),
serde_json::Value::Array(enum_values),
);
}
} else if value_type.starts_with("list[") {
schema_entry.insert("type".to_string(), serde_json::json!("array"));
if let Some(item_type) = value_type
.strip_prefix("list[")
.and_then(|s| s.strip_suffix(']'))
{
let json_type = match item_type {
"str" => "string",
"int" => "integer",
"bool" => "boolean",
_ => "string",
};
schema_entry.insert(
"items".to_string(),
serde_json::json!({"type": json_type}),
);
}
} else if value_type.starts_with("dict[") {
schema_entry.insert("type".to_string(), serde_json::json!("object"));
} else {
let json_type = match value_type {
"bool" => "boolean",
"int" | "usize" => "integer",
"str" => "string",
_ => "string",
};
schema_entry.insert(
"type".to_string(),
serde_json::Value::String(json_type.to_string()),
);
}
}

current.insert(part.to_string(), serde_json::Value::Object(schema_entry));
} else {
let next_current = current
.entry(part.to_string())
.or_insert_with(|| {
serde_json::json!({
"type": "object",
"properties": {}
})
})
.as_object_mut()
.expect("should be an object")
.entry("properties")
.or_insert_with(|| serde_json::json!({}))
.as_object_mut()
.expect("properties should be an object");

current = next_current;
}
}
}

serde_json::json!({
"type": "object",
"properties": root_properties
})
}
}
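A minimal standalone model of the nesting step shared by convert_ruff_schema above and convert_rust_analyzer_schema below: dotted keys such as "lint.isort.case-sensitive" are folded into a JSON-Schema-style "properties" tree. It assumes only the serde_json crate already used here; per-key type and enum handling is intentionally omitted, and `nest_dotted_keys` is a hypothetical helper rather than code from the change.

fn nest_dotted_keys(flat: &serde_json::Map<String, serde_json::Value>) -> serde_json::Value {
    let mut root = serde_json::Map::new();
    for (key, value) in flat {
        let parts: Vec<&str> = key.split('.').collect();
        let mut current = &mut root;
        for (i, part) in parts.iter().enumerate() {
            if i == parts.len() - 1 {
                // Leaf: keep the entry as provided.
                current.insert(part.to_string(), value.clone());
            } else {
                // Intermediate segment: descend into (or create) a nested object schema.
                let next_current = current
                    .entry(part.to_string())
                    .or_insert_with(|| serde_json::json!({"type": "object", "properties": {}}))
                    .as_object_mut()
                    .expect("should be an object")
                    .entry("properties")
                    .or_insert_with(|| serde_json::json!({}))
                    .as_object_mut()
                    .expect("properties should be an object");
                current = next_current;
            }
        }
    }
    serde_json::json!({"type": "object", "properties": root})
}

fn main() {
    let flat = serde_json::json!({
        "line-length": {"type": "integer"},
        "lint.select": {"type": "array"}
    });
    let nested = nest_dotted_keys(flat.as_object().unwrap());
    assert!(nested["properties"]["line-length"].is_object());
    assert!(nested["properties"]["lint"]["properties"]["select"].is_object());
}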
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
impl RuffLspAdapter {
|
||||
const GITHUB_ASSET_KIND: AssetKind = AssetKind::TarGz;
|
||||
@@ -2225,6 +2339,36 @@ impl LspAdapter for RuffLspAdapter {
|
||||
fn name(&self) -> LanguageServerName {
|
||||
Self::SERVER_NAME
|
||||
}
|
||||
|
||||
async fn initialization_options_schema(
|
||||
self: Arc<Self>,
|
||||
language_server_binary: &LanguageServerBinary,
|
||||
) -> Option<serde_json::Value> {
|
||||
let mut command = util::command::new_smol_command(&language_server_binary.path);
|
||||
command
|
||||
.args(&["config", "--output-format", "json"])
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
let cmd = command
|
||||
.spawn()
|
||||
.map_err(|e| log::debug!("failed to spawn command {command:?}: {e}"))
|
||||
.ok()?;
|
||||
let output = cmd
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| log::debug!("failed to execute command {command:?}: {e}"))
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let raw_schema: serde_json::Value = serde_json::from_slice(output.stdout.as_slice())
|
||||
.map_err(|e| log::debug!("failed to parse ruff's JSON schema output: {e}"))
|
||||
.ok()?;
|
||||
|
||||
let converted_schema = Self::convert_ruff_schema(&raw_schema);
|
||||
Some(converted_schema)
|
||||
}
|
||||
}
|
||||
|
||||
impl LspInstaller for RuffLspAdapter {
|
||||
@@ -2568,4 +2712,149 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_ruff_schema() {
|
||||
use super::RuffLspAdapter;
|
||||
|
||||
let raw_schema = serde_json::json!({
|
||||
"line-length": {
|
||||
"doc": "The line length to use when enforcing long-lines violations",
|
||||
"default": "88",
|
||||
"value_type": "int",
|
||||
"scope": null,
|
||||
"example": "line-length = 120",
|
||||
"deprecated": null
|
||||
},
|
||||
"lint.select": {
|
||||
"doc": "A list of rule codes or prefixes to enable",
|
||||
"default": "[\"E4\", \"E7\", \"E9\", \"F\"]",
|
||||
"value_type": "list[RuleSelector]",
|
||||
"scope": null,
|
||||
"example": "select = [\"E4\", \"E7\", \"E9\", \"F\", \"B\", \"Q\"]",
|
||||
"deprecated": null
|
||||
},
|
||||
"lint.isort.case-sensitive": {
|
||||
"doc": "Sort imports taking into account case sensitivity.",
|
||||
"default": "false",
|
||||
"value_type": "bool",
|
||||
"scope": null,
|
||||
"example": "case-sensitive = true",
|
||||
"deprecated": null
|
||||
},
|
||||
"format.quote-style": {
|
||||
"doc": "Configures the preferred quote character for strings.",
|
||||
"default": "\"double\"",
|
||||
"value_type": "\"double\" | \"single\" | \"preserve\"",
|
||||
"scope": null,
|
||||
"example": "quote-style = \"single\"",
|
||||
"deprecated": null
|
||||
}
|
||||
});
|
||||
|
||||
let converted = RuffLspAdapter::convert_ruff_schema(&raw_schema);
|
||||
|
||||
assert!(converted.is_object());
|
||||
assert_eq!(
|
||||
converted.get("type").and_then(|v| v.as_str()),
|
||||
Some("object")
|
||||
);
|
||||
|
||||
let properties = converted
|
||||
.get("properties")
|
||||
.expect("should have properties")
|
||||
.as_object()
|
||||
.expect("properties should be an object");
|
||||
|
||||
assert!(properties.contains_key("line-length"));
|
||||
assert!(properties.contains_key("lint"));
|
||||
assert!(properties.contains_key("format"));
|
||||
|
||||
let line_length = properties
|
||||
.get("line-length")
|
||||
.expect("should have line-length")
|
||||
.as_object()
|
||||
.expect("line-length should be an object");
|
||||
|
||||
assert_eq!(
|
||||
line_length.get("type").and_then(|v| v.as_str()),
|
||||
Some("integer")
|
||||
);
|
||||
assert_eq!(
|
||||
line_length.get("default").and_then(|v| v.as_str()),
|
||||
Some("88")
|
||||
);
|
||||
|
||||
let lint = properties
|
||||
.get("lint")
|
||||
.expect("should have lint")
|
||||
.as_object()
|
||||
.expect("lint should be an object");
|
||||
|
||||
let lint_props = lint
|
||||
.get("properties")
|
||||
.expect("lint should have properties")
|
||||
.as_object()
|
||||
.expect("lint properties should be an object");
|
||||
|
||||
assert!(lint_props.contains_key("select"));
|
||||
assert!(lint_props.contains_key("isort"));
|
||||
|
||||
let select = lint_props.get("select").expect("should have select");
|
||||
assert_eq!(select.get("type").and_then(|v| v.as_str()), Some("array"));
|
||||
|
||||
let isort = lint_props
|
||||
.get("isort")
|
||||
.expect("should have isort")
|
||||
.as_object()
|
||||
.expect("isort should be an object");
|
||||
|
||||
let isort_props = isort
|
||||
.get("properties")
|
||||
.expect("isort should have properties")
|
||||
.as_object()
|
||||
.expect("isort properties should be an object");
|
||||
|
||||
let case_sensitive = isort_props
|
||||
.get("case-sensitive")
|
||||
.expect("should have case-sensitive");
|
||||
|
||||
assert_eq!(
|
||||
case_sensitive.get("type").and_then(|v| v.as_str()),
|
||||
Some("boolean")
|
||||
);
|
||||
assert!(case_sensitive.get("markdownDescription").is_some());
|
||||
|
||||
let format = properties
|
||||
.get("format")
|
||||
.expect("should have format")
|
||||
.as_object()
|
||||
.expect("format should be an object");
|
||||
|
||||
let format_props = format
|
||||
.get("properties")
|
||||
.expect("format should have properties")
|
||||
.as_object()
|
||||
.expect("format properties should be an object");
|
||||
|
||||
let quote_style = format_props
|
||||
.get("quote-style")
|
||||
.expect("should have quote-style");
|
||||
|
||||
assert_eq!(
|
||||
quote_style.get("type").and_then(|v| v.as_str()),
|
||||
Some("string")
|
||||
);
|
||||
|
||||
let enum_values = quote_style
|
||||
.get("enum")
|
||||
.expect("should have enum")
|
||||
.as_array()
|
||||
.expect("enum should be an array");
|
||||
|
||||
assert_eq!(enum_values.len(), 3);
|
||||
assert!(enum_values.contains(&serde_json::json!("double")));
|
||||
assert!(enum_values.contains(&serde_json::json!("single")));
|
||||
assert!(enum_values.contains(&serde_json::json!("preserve")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ use smol::fs::{self};
use std::cmp::Reverse;
use std::fmt::Display;
use std::ops::Range;
use std::process::Stdio;
use std::{
    borrow::Cow,
    path::{Path, PathBuf},
@@ -66,6 +67,68 @@ enum LibcType {
}

impl RustLspAdapter {
    fn convert_rust_analyzer_schema(raw_schema: &serde_json::Value) -> serde_json::Value {
        let Some(schema_array) = raw_schema.as_array() else {
            return raw_schema.clone();
        };

        let mut root_properties = serde_json::Map::new();

        for item in schema_array {
            if let Some(props) = item.get("properties").and_then(|p| p.as_object()) {
                for (key, value) in props {
                    let parts: Vec<&str> = key.split('.').collect();

                    if parts.is_empty() {
                        continue;
                    }

                    let parts_to_process = if parts.first() == Some(&"rust-analyzer") {
                        &parts[1..]
                    } else {
                        &parts[..]
                    };

                    if parts_to_process.is_empty() {
                        continue;
                    }

                    let mut current = &mut root_properties;

                    for (i, part) in parts_to_process.iter().enumerate() {
                        let is_last = i == parts_to_process.len() - 1;

                        if is_last {
                            current.insert(part.to_string(), value.clone());
                        } else {
                            let next_current = current
                                .entry(part.to_string())
                                .or_insert_with(|| {
                                    serde_json::json!({
                                        "type": "object",
                                        "properties": {}
                                    })
                                })
                                .as_object_mut()
                                .expect("should be an object")
                                .entry("properties")
                                .or_insert_with(|| serde_json::json!({}))
                                .as_object_mut()
                                .expect("properties should be an object");

                            current = next_current;
                        }
                    }
                }
            }
        }

        serde_json::json!({
            "type": "object",
            "properties": root_properties
        })
    }

    #[cfg(target_os = "linux")]
    async fn determine_libc_type() -> LibcType {
        use futures::pin_mut;
@@ -448,6 +511,37 @@ impl LspAdapter for RustLspAdapter {
        Some(label)
    }

    async fn initialization_options_schema(
        self: Arc<Self>,
        language_server_binary: &LanguageServerBinary,
    ) -> Option<serde_json::Value> {
        let mut command = util::command::new_smol_command(&language_server_binary.path);
        command
            .arg("--print-config-schema")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let cmd = command
            .spawn()
            .map_err(|e| log::debug!("failed to spawn command {command:?}: {e}"))
            .ok()?;
        let output = cmd
            .output()
            .await
            .map_err(|e| log::debug!("failed to execute command {command:?}: {e}"))
            .ok()?;
        if !output.status.success() {
            return None;
        }

        let raw_schema: serde_json::Value = serde_json::from_slice(output.stdout.as_slice())
            .map_err(|e| log::debug!("failed to parse rust-analyzer's JSON schema output: {e}"))
            .ok()?;

        // Convert rust-analyzer's array-based schema format to nested JSON Schema
        let converted_schema = Self::convert_rust_analyzer_schema(&raw_schema);
        Some(converted_schema)
    }

    async fn label_for_symbol(
        &self,
        name: &str,
@@ -1912,4 +2006,90 @@ mod tests {
        );
        check([], "/project/src/main.rs", "--");
    }

    #[test]
    fn test_convert_rust_analyzer_schema() {
        let raw_schema = serde_json::json!([
            {
                "title": "Assist",
                "properties": {
                    "rust-analyzer.assist.emitMustUse": {
                        "markdownDescription": "Insert #[must_use] when generating `as_` methods for enum variants.",
                        "default": false,
                        "type": "boolean"
                    }
                }
            },
            {
                "title": "Assist",
                "properties": {
                    "rust-analyzer.assist.expressionFillDefault": {
                        "markdownDescription": "Placeholder expression to use for missing expressions in assists.",
                        "default": "todo",
                        "type": "string"
                    }
                }
            },
            {
                "title": "Cache Priming",
                "properties": {
                    "rust-analyzer.cachePriming.enable": {
                        "markdownDescription": "Warm up caches on project load.",
                        "default": true,
                        "type": "boolean"
                    }
                }
            }
        ]);

        let converted = RustLspAdapter::convert_rust_analyzer_schema(&raw_schema);

        assert_eq!(
            converted.get("type").and_then(|v| v.as_str()),
            Some("object")
        );

        let properties = converted
            .pointer("/properties")
            .expect("should have properties")
            .as_object()
            .expect("properties should be object");

        assert!(properties.contains_key("assist"));
        assert!(properties.contains_key("cachePriming"));
        assert!(!properties.contains_key("rust-analyzer"));

        let assist_props = properties
            .get("assist")
            .expect("should have assist")
            .pointer("/properties")
            .expect("assist should have properties")
            .as_object()
            .expect("assist properties should be object");

        assert!(assist_props.contains_key("emitMustUse"));
        assert!(assist_props.contains_key("expressionFillDefault"));

        let emit_must_use = assist_props
            .get("emitMustUse")
            .expect("should have emitMustUse");
        assert_eq!(
            emit_must_use.get("type").and_then(|v| v.as_str()),
            Some("boolean")
        );
        assert_eq!(
            emit_must_use.get("default").and_then(|v| v.as_bool()),
            Some(false)
        );

        let cache_priming_props = properties
            .get("cachePriming")
            .expect("should have cachePriming")
            .pointer("/properties")
            .expect("cachePriming should have properties")
            .as_object()
            .expect("cachePriming properties should be object");

        assert!(cache_priming_props.contains_key("enable"));
    }
}

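As a rough standalone sketch (not part of the diff), the same raw schema the adapter fetches above can be reproduced with std::process instead of smol. The only assumptions are a rust-analyzer binary on PATH and the serde_json crate; the `--print-config-schema` flag and the array-of-properties output format are taken from the adapter code and test above.

// Standalone sketch: fetch rust-analyzer's raw config schema the way the adapter does,
// but via std::process. Assumes `rust-analyzer` is on PATH and serde_json is a dependency.
use std::process::Command;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let output = Command::new("rust-analyzer")
        .arg("--print-config-schema")
        .output()?;
    if !output.status.success() {
        return Err("rust-analyzer exited with a failure status".into());
    }

    // The raw output is an array of { "title", "properties": { "rust-analyzer.a.b": { ... } } }
    // entries; a converter like `convert_rust_analyzer_schema` above nests the dotted keys into
    // a single { "type": "object", "properties": { "a": { "properties": { "b": ... } } } } value.
    let raw: serde_json::Value = serde_json::from_slice(&output.stdout)?;
    println!(
        "{} top-level schema entries",
        raw.as_array().map_or(0, |entries| entries.len())
    );
    Ok(())
}
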
@@ -345,6 +345,7 @@ impl LspAdapter for VtslsLspAdapter {
        let lsp_settings = content
            .project
            .lsp
            .0
            .entry(VTSLS_SERVER_NAME.into())
            .or_default();

@@ -8,6 +8,7 @@ use language::LanguageName;
use log::Level;
pub use path_range::{LineCol, PathWithRange};
use ui::Checkbox;
use ui::CopyButton;

use std::borrow::Cow;
use std::iter;
@@ -22,9 +23,9 @@ use collections::{HashMap, HashSet};
use gpui::{
    AnyElement, App, BorderStyle, Bounds, ClipboardItem, CursorStyle, DispatchPhase, Edges, Entity,
    FocusHandle, Focusable, FontStyle, FontWeight, GlobalElementId, Hitbox, Hsla, Image,
    ImageFormat, KeyContext, Length, MouseDownEvent, MouseEvent, MouseMoveEvent, MouseUpEvent,
    Point, ScrollHandle, Stateful, StrikethroughStyle, StyleRefinement, StyledText, Task,
    TextLayout, TextRun, TextStyle, TextStyleRefinement, actions, img, point, quad,
    ImageFormat, KeyContext, Length, MouseButton, MouseDownEvent, MouseEvent, MouseMoveEvent,
    MouseUpEvent, Point, ScrollHandle, Stateful, StrikethroughStyle, StyleRefinement, StyledText,
    Task, TextLayout, TextRun, TextStyle, TextStyleRefinement, actions, img, point, quad,
};
use language::{Language, LanguageRegistry, Rope};
use parser::CodeBlockMetadata;
@@ -32,7 +33,7 @@ use parser::{MarkdownEvent, MarkdownTag, MarkdownTagEnd, parse_links_only, parse
use pulldown_cmark::Alignment;
use sum_tree::TreeMap;
use theme::SyntaxTheme;
use ui::{ScrollAxes, Scrollbars, Tooltip, WithScrollbar, prelude::*};
use ui::{ScrollAxes, Scrollbars, WithScrollbar, prelude::*};
use util::ResultExt;

use crate::parser::CodeBlockKind;
@@ -112,6 +113,7 @@ pub struct Markdown {
    options: Options,
    copied_code_blocks: HashSet<ElementId>,
    code_block_scroll_handles: HashMap<usize, ScrollHandle>,
    context_menu_selected_text: Option<String>,
}

struct Options {
@@ -181,6 +183,7 @@ impl Markdown {
            },
            copied_code_blocks: HashSet::default(),
            code_block_scroll_handles: HashMap::default(),
            context_menu_selected_text: None,
        };
        this.parse(cx);
        this
@@ -205,6 +208,7 @@ impl Markdown {
            },
            copied_code_blocks: HashSet::default(),
            code_block_scroll_handles: HashMap::default(),
            context_menu_selected_text: None,
        };
        this.parse(cx);
        this
@@ -289,6 +293,14 @@ impl Markdown {
        }
    }

    pub fn selected_text(&self) -> Option<String> {
        if self.selection.end <= self.selection.start {
            None
        } else {
            Some(self.source[self.selection.start..self.selection.end].to_string())
        }
    }

    fn copy(&self, text: &RenderedText, _: &mut Window, cx: &mut Context<Self>) {
        if self.selection.end <= self.selection.start {
            return;
@@ -297,7 +309,11 @@ impl Markdown {
        cx.write_to_clipboard(ClipboardItem::new_string(text));
    }

    fn copy_as_markdown(&self, _: &mut Window, cx: &mut Context<Self>) {
    fn copy_as_markdown(&mut self, _: &mut Window, cx: &mut Context<Self>) {
        if let Some(text) = self.context_menu_selected_text.take() {
            cx.write_to_clipboard(ClipboardItem::new_string(text));
            return;
        }
        if self.selection.end <= self.selection.start {
            return;
        }
@@ -305,6 +321,10 @@ impl Markdown {
        cx.write_to_clipboard(ClipboardItem::new_string(text));
    }

    fn capture_selection_for_context_menu(&mut self) {
        self.context_menu_selected_text = self.selected_text();
    }

    fn parse(&mut self, cx: &mut Context<Self>) {
        if self.source.is_empty() {
            return;
@@ -665,6 +685,19 @@ impl MarkdownElement {

        let on_open_url = self.on_url_click.take();

        self.on_mouse_event(window, cx, {
            let hitbox = hitbox.clone();
            move |markdown, event: &MouseDownEvent, phase, window, _| {
                if phase.capture()
                    && event.button == MouseButton::Right
                    && hitbox.is_hovered(window)
                {
                    // Capture selected text so it survives until menu item is clicked
                    markdown.capture_selection_for_context_menu();
                }
            }
        });

        self.on_mouse_event(window, cx, {
            let rendered_text = rendered_text.clone();
            let hitbox = hitbox.clone();
@@ -713,7 +746,7 @@ impl MarkdownElement {
                    window.prevent_default();
                    cx.notify();
                }
            } else if phase.capture() {
            } else if phase.capture() && event.button == MouseButton::Left {
                markdown.selection = Selection::default();
                markdown.pressed_link = None;
                cx.notify();
@@ -1170,7 +1203,6 @@ impl Element for MarkdownElement {
                            range.end,
                            code,
                            self.markdown.clone(),
                            cx,
                        );
                        el.child(
                            h_flex()
@@ -1201,7 +1233,6 @@ impl Element for MarkdownElement {
                            range.end,
                            code,
                            self.markdown.clone(),
                            cx,
                        );
                        el.child(
                            h_flex()
@@ -1417,26 +1448,12 @@ fn render_copy_code_block_button(
    id: usize,
    code: String,
    markdown: Entity<Markdown>,
    cx: &App,
) -> impl IntoElement {
    let id = ElementId::named_usize("copy-markdown-code", id);
    let was_copied = markdown.read(cx).copied_code_blocks.contains(&id);
    IconButton::new(
        id.clone(),
        if was_copied {
            IconName::Check
        } else {
            IconName::Copy
        },
    )
    .icon_color(Color::Muted)
    .icon_size(IconSize::Small)
    .style(ButtonStyle::Filled)
    .shape(ui::IconButtonShape::Square)
    .tooltip(Tooltip::text("Copy"))
    .on_click({

    CopyButton::new(code.clone()).custom_on_click({
        let markdown = markdown;
        move |_event, _window, cx| {
        move |_window, cx| {
            let id = id.clone();
            markdown.update(cx, |this, cx| {
                this.copied_code_blocks.insert(id.clone());

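As a rough illustration of the behavior the markdown changes above add, here is a simplified standalone sketch using plain types rather than gpui. The struct and method names mirror the diff but are otherwise illustrative: the selection is captured when the right-click arrives, and consumed later when the context-menu action runs, so the left-click that picks the menu item can clear the live selection without losing the text.

// Simplified sketch of the capture-on-right-click pattern used above.
// Plain structs stand in for the gpui-backed Markdown entity; names are illustrative.
use std::ops::Range;

struct MarkdownSelection {
    source: String,
    selection: Range<usize>, // byte offsets into `source`
    context_menu_selected_text: Option<String>,
}

impl MarkdownSelection {
    fn selected_text(&self) -> Option<String> {
        if self.selection.end <= self.selection.start {
            None
        } else {
            Some(self.source[self.selection.start..self.selection.end].to_string())
        }
    }

    // Called when the right mouse button goes down, before the menu opens.
    fn capture_selection_for_context_menu(&mut self) {
        self.context_menu_selected_text = self.selected_text();
    }

    // Called when the "Copy as Markdown" menu item fires; the left-click that chose
    // the menu item may already have reset `selection`.
    fn copy_as_markdown(&mut self) -> Option<String> {
        if let Some(text) = self.context_menu_selected_text.take() {
            return Some(text);
        }
        self.selected_text()
    }
}

fn main() {
    let mut md = MarkdownSelection {
        source: "hello **world**".to_string(),
        selection: 6..15,
        context_menu_selected_text: None,
    };
    md.capture_selection_for_context_menu();
    md.selection = 0..0; // the left-click on the menu item clears the live selection
    assert_eq!(md.copy_as_markdown().as_deref(), Some("**world**"));
    println!("second call: {:?}", md.copy_as_markdown()); // None once the capture is consumed
}
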
Some files were not shown because too many files have changed in this diff.