Compare commits
34 Commits
inline-ass
...
new-provid
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc5f5e35e4 | ||
|
|
7183b8a1cd | ||
|
|
b1934fb712 | ||
|
|
a198b6c0d1 | ||
|
|
8b5b2712c8 | ||
|
|
4464392e8e | ||
|
|
a0d3bc31e9 | ||
|
|
ccd6672d1a | ||
|
|
21de6d35dd | ||
|
|
2031ca17e5 | ||
|
|
8b1ce75a57 | ||
|
|
5559726fd7 | ||
|
|
e1a9269921 | ||
|
|
3b6b3ff504 | ||
|
|
aabed94970 | ||
|
|
2d3a3521ba | ||
|
|
a48bd10da0 | ||
|
|
fec9525be4 | ||
|
|
bf2b8e999e | ||
|
|
63c35d2b00 | ||
|
|
1396c68010 | ||
|
|
fcb3d3dec6 | ||
|
|
f54e7f8c9d | ||
|
|
2a89529d7f | ||
|
|
58207325e2 | ||
|
|
e08ab99e8d | ||
|
|
a95f3f33a4 | ||
|
|
b0767c1b1f | ||
|
|
b200e10bc4 | ||
|
|
948905d916 | ||
|
|
04de456373 | ||
|
|
e5ce32e936 | ||
|
|
d7caae30de | ||
|
|
c7e77674a1 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -39,6 +39,3 @@ xcuserdata/
|
||||
# Don't commit any secrets to the repo.
|
||||
.env
|
||||
.env.secret.toml
|
||||
|
||||
# `nix build` output
|
||||
/result
|
||||
|
||||
1434
Cargo.lock
generated
1434
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
48
Cargo.toml
48
Cargo.toml
@@ -54,9 +54,9 @@ members = [
|
||||
"crates/diagnostics",
|
||||
"crates/docs_preprocessor",
|
||||
"crates/edit_prediction",
|
||||
"crates/edit_prediction_types",
|
||||
"crates/edit_prediction_ui",
|
||||
"crates/edit_prediction_button",
|
||||
"crates/edit_prediction_context",
|
||||
"crates/zeta2_tools",
|
||||
"crates/editor",
|
||||
"crates/eval",
|
||||
"crates/eval_utils",
|
||||
@@ -201,11 +201,10 @@ members = [
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zed_env_vars",
|
||||
"crates/edit_prediction_cli",
|
||||
"crates/zeta",
|
||||
"crates/zeta_cli",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
"crates/ztracing",
|
||||
"crates/ztracing_macro",
|
||||
|
||||
#
|
||||
# Extensions
|
||||
@@ -244,6 +243,7 @@ activity_indicator = { path = "crates/activity_indicator" }
|
||||
agent_ui = { path = "crates/agent_ui" }
|
||||
agent_settings = { path = "crates/agent_settings" }
|
||||
agent_servers = { path = "crates/agent_servers" }
|
||||
ai = { path = "crates/ai" }
|
||||
ai_onboarding = { path = "crates/ai_onboarding" }
|
||||
anthropic = { path = "crates/anthropic" }
|
||||
askpass = { path = "crates/askpass" }
|
||||
@@ -253,6 +253,7 @@ assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
|
||||
audio = { path = "crates/audio" }
|
||||
auto_update = { path = "crates/auto_update" }
|
||||
auto_update_helper = { path = "crates/auto_update_helper" }
|
||||
auto_update_ui = { path = "crates/auto_update_ui" }
|
||||
aws_http_client = { path = "crates/aws_http_client" }
|
||||
bedrock = { path = "crates/bedrock" }
|
||||
@@ -267,6 +268,7 @@ cloud_api_client = { path = "crates/cloud_api_client" }
|
||||
cloud_api_types = { path = "crates/cloud_api_types" }
|
||||
cloud_llm_client = { path = "crates/cloud_llm_client" }
|
||||
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
|
||||
collab = { path = "crates/collab" }
|
||||
collab_ui = { path = "crates/collab_ui" }
|
||||
collections = { path = "crates/collections", version = "0.1.0" }
|
||||
command_palette = { path = "crates/command_palette" }
|
||||
@@ -311,9 +313,10 @@ http_client = { path = "crates/http_client" }
|
||||
http_client_tls = { path = "crates/http_client_tls" }
|
||||
icons = { path = "crates/icons" }
|
||||
image_viewer = { path = "crates/image_viewer" }
|
||||
edit_prediction_types = { path = "crates/edit_prediction_types" }
|
||||
edit_prediction_ui = { path = "crates/edit_prediction_ui" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
edit_prediction_button = { path = "crates/edit_prediction_button" }
|
||||
edit_prediction_context = { path = "crates/edit_prediction_context" }
|
||||
zeta2_tools = { path = "crates/zeta2_tools" }
|
||||
inspector_ui = { path = "crates/inspector_ui" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
journal = { path = "crates/journal" }
|
||||
@@ -355,6 +358,8 @@ panel = { path = "crates/panel" }
|
||||
paths = { path = "crates/paths" }
|
||||
perf = { path = "tooling/perf" }
|
||||
picker = { path = "crates/picker" }
|
||||
plugin = { path = "crates/plugin" }
|
||||
plugin_macros = { path = "crates/plugin_macros" }
|
||||
prettier = { path = "crates/prettier" }
|
||||
settings_profile_selector = { path = "crates/settings_profile_selector" }
|
||||
project = { path = "crates/project" }
|
||||
@@ -365,10 +370,12 @@ proto = { path = "crates/proto" }
|
||||
recent_projects = { path = "crates/recent_projects" }
|
||||
refineable = { path = "crates/refineable" }
|
||||
release_channel = { path = "crates/release_channel" }
|
||||
scheduler = { path = "crates/scheduler" }
|
||||
remote = { path = "crates/remote" }
|
||||
remote_server = { path = "crates/remote_server" }
|
||||
repl = { path = "crates/repl" }
|
||||
reqwest_client = { path = "crates/reqwest_client" }
|
||||
rich_text = { path = "crates/rich_text" }
|
||||
rodio = { git = "https://github.com/RustAudio/rodio", rev ="e2074c6c2acf07b57cf717e076bdda7a9ac6e70b", features = ["wav", "playback", "wav_output", "recording"] }
|
||||
rope = { path = "crates/rope" }
|
||||
rpc = { path = "crates/rpc" }
|
||||
@@ -385,6 +392,7 @@ snippets_ui = { path = "crates/snippets_ui" }
|
||||
sqlez = { path = "crates/sqlez" }
|
||||
sqlez_macros = { path = "crates/sqlez_macros" }
|
||||
story = { path = "crates/story" }
|
||||
storybook = { path = "crates/storybook" }
|
||||
streaming_diff = { path = "crates/streaming_diff" }
|
||||
sum_tree = { path = "crates/sum_tree" }
|
||||
supermaven = { path = "crates/supermaven" }
|
||||
@@ -401,6 +409,7 @@ terminal_view = { path = "crates/terminal_view" }
|
||||
text = { path = "crates/text" }
|
||||
theme = { path = "crates/theme" }
|
||||
theme_extension = { path = "crates/theme_extension" }
|
||||
theme_importer = { path = "crates/theme_importer" }
|
||||
theme_selector = { path = "crates/theme_selector" }
|
||||
time_format = { path = "crates/time_format" }
|
||||
title_bar = { path = "crates/title_bar" }
|
||||
@@ -424,17 +433,15 @@ x_ai = { path = "crates/x_ai" }
|
||||
zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
zed_env_vars = { path = "crates/zed_env_vars" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
zeta = { path = "crates/zeta" }
|
||||
zlog = { path = "crates/zlog" }
|
||||
zlog_settings = { path = "crates/zlog_settings" }
|
||||
ztracing = { path = "crates/ztracing" }
|
||||
ztracing_macro = { path = "crates/ztracing_macro" }
|
||||
|
||||
#
|
||||
# External crates
|
||||
#
|
||||
|
||||
agent-client-protocol = { version = "=0.9.0", features = ["unstable"] }
|
||||
agent-client-protocol = { version = "=0.8.0", features = ["unstable"] }
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = "0.25.1-rc1"
|
||||
any_vec = "0.14"
|
||||
@@ -501,11 +508,13 @@ exec = "0.3.1"
|
||||
fancy-regex = "0.16.0"
|
||||
fork = "0.4.0"
|
||||
futures = "0.3"
|
||||
futures-batch = "0.6.1"
|
||||
futures-lite = "1.13"
|
||||
gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "09acfdf2bd5c1d6254abefd609c808ff73547b2c" }
|
||||
git2 = { version = "0.20.1", default-features = false }
|
||||
globset = "0.4"
|
||||
handlebars = "4.3"
|
||||
hashbrown = "0.15.3"
|
||||
heck = "0.5"
|
||||
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
|
||||
hex = "0.4.3"
|
||||
@@ -541,6 +550,7 @@ nanoid = "0.4"
|
||||
nbformat = "0.15.0"
|
||||
nix = "0.29"
|
||||
num-format = "0.4.4"
|
||||
num-traits = "0.2"
|
||||
objc = "0.2"
|
||||
objc2-foundation = { version = "=0.3.1", default-features = false, features = [
|
||||
"NSArray",
|
||||
@@ -579,6 +589,7 @@ pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev =
|
||||
pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
|
||||
@@ -618,6 +629,7 @@ scap = { git = "https://github.com/zed-industries/scap", rev = "4afea48c3b002197
|
||||
schemars = { version = "1.0", features = ["indexmap2"] }
|
||||
semver = { version = "1.0", features = ["serde"] }
|
||||
serde = { version = "1.0.221", features = ["derive", "rc"] }
|
||||
serde_derive = "1.0.221"
|
||||
serde_json = { version = "1.0.144", features = ["preserve_order", "raw_value"] }
|
||||
serde_json_lenient = { version = "0.2", features = [
|
||||
"preserve_order",
|
||||
@@ -629,6 +641,7 @@ serde_urlencoded = "0.7"
|
||||
sha2 = "0.10"
|
||||
shellexpand = "2.1.0"
|
||||
shlex = "1.3.0"
|
||||
similar = "2.6"
|
||||
simplelog = "0.12.2"
|
||||
slotmap = "1.0.6"
|
||||
smallvec = { version = "1.6", features = ["union"] }
|
||||
@@ -683,7 +696,6 @@ tree-sitter-ruby = "0.23"
|
||||
tree-sitter-rust = "0.24"
|
||||
tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347
|
||||
tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" }
|
||||
tracing = "0.1.40"
|
||||
unicase = "2.6"
|
||||
unicode-script = "0.5.7"
|
||||
unicode-segmentation = "1.10"
|
||||
@@ -707,6 +719,7 @@ wasmtime-wasi = "29"
|
||||
wax = "0.6"
|
||||
which = "6.0.0"
|
||||
windows-core = "0.61"
|
||||
wit-component = "0.221"
|
||||
yawc = "0.2.5"
|
||||
zeroize = "1.8"
|
||||
zstd = "0.11"
|
||||
@@ -788,13 +801,20 @@ settings_macros = { opt-level = 3 }
|
||||
sqlez_macros = { opt-level = 3, codegen-units = 1 }
|
||||
ui_macros = { opt-level = 3 }
|
||||
util_macros = { opt-level = 3 }
|
||||
serde_derive = { opt-level = 3 }
|
||||
quote = { opt-level = 3 }
|
||||
syn = { opt-level = 3 }
|
||||
proc-macro2 = { opt-level = 3 }
|
||||
# proc-macros end
|
||||
|
||||
taffy = { opt-level = 3 }
|
||||
cranelift-codegen = { opt-level = 3 }
|
||||
cranelift-codegen-meta = { opt-level = 3 }
|
||||
cranelift-codegen-shared = { opt-level = 3 }
|
||||
resvg = { opt-level = 3 }
|
||||
rustybuzz = { opt-level = 3 }
|
||||
ttf-parser = { opt-level = 3 }
|
||||
wasmtime-cranelift = { opt-level = 3 }
|
||||
wasmtime = { opt-level = 3 }
|
||||
# Build single-source-file crates with cg=1 as it helps make `cargo build` of a whole workspace a bit faster
|
||||
activity_indicator = { codegen-units = 1 }
|
||||
@@ -803,11 +823,12 @@ breadcrumbs = { codegen-units = 1 }
|
||||
collections = { codegen-units = 1 }
|
||||
command_palette = { codegen-units = 1 }
|
||||
command_palette_hooks = { codegen-units = 1 }
|
||||
extension_cli = { codegen-units = 1 }
|
||||
feature_flags = { codegen-units = 1 }
|
||||
file_icons = { codegen-units = 1 }
|
||||
fsevent = { codegen-units = 1 }
|
||||
image_viewer = { codegen-units = 1 }
|
||||
edit_prediction_ui = { codegen-units = 1 }
|
||||
edit_prediction_button = { codegen-units = 1 }
|
||||
install_cli = { codegen-units = 1 }
|
||||
journal = { codegen-units = 1 }
|
||||
json_schema_store = { codegen-units = 1 }
|
||||
@@ -822,6 +843,7 @@ project_symbols = { codegen-units = 1 }
|
||||
refineable = { codegen-units = 1 }
|
||||
release_channel = { codegen-units = 1 }
|
||||
reqwest_client = { codegen-units = 1 }
|
||||
rich_text = { codegen-units = 1 }
|
||||
session = { codegen-units = 1 }
|
||||
snippet = { codegen-units = 1 }
|
||||
snippets_ui = { codegen-units = 1 }
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4 2V10" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12 6C12.5304 6 13.0391 5.78929 13.4142 5.41421C13.7893 5.03914 14 4.53043 14 4C14 3.46957 13.7893 2.96086 13.4142 2.58579C13.0391 2.21071 12.5304 2 12 2C11.4696 2 10.9609 2.21071 10.5858 2.58579C10.2107 2.96086 10 3.46957 10 4C10 4.53043 10.2107 5.03914 10.5858 5.41421C10.9609 5.78929 11.4696 6 12 6Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4 14C4.53043 14 5.03914 13.7893 5.41421 13.4142C5.78929 13.0391 6 12.5304 6 12C6 11.4696 5.78929 10.9609 5.41421 10.5858C5.03914 10.2107 4.53043 10 4 10C3.46957 10 2.96086 10.2107 2.58579 10.5858C2.21071 10.9609 2 11.4696 2 12C2 12.5304 2.21071 13.0391 2.58579 13.4142C2.96086 13.7893 3.46957 14 4 14Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10 4C8.4087 4 6.88258 4.63214 5.75736 5.75736C4.63214 6.88258 4 8.4087 4 10" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12 10V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M14 12H10" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.4 KiB |
@@ -1,11 +0,0 @@
|
||||
<svg width="28" height="28" viewBox="0 0 28 28" fill="none" id="svg1378540956_510">
|
||||
<g clip-path="url(#svg1378540956_510_clip0_1_1506)" transform="translate(4, 4) scale(0.857)">
|
||||
<path d="M17.0547 0.372066H8.52652L-0.00165176 8.90024V17.4284H8.52652V8.90024H17.0547V0.372066Z" fill="#1A1C20"></path>
|
||||
<path d="M10.1992 27.6279H18.7274L27.2556 19.0998V10.5716H18.7274V19.0998H10.1992V27.6279Z" fill="#1A1C20"></path>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="svg1378540956_510_clip0_1_1506">
|
||||
<rect width="27.2559" height="27.2559" fill="white" transform="translate(0 0.37207)"></rect>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 593 B |
@@ -41,7 +41,7 @@
|
||||
"ctrl-f11": "debugger::StepInto",
|
||||
"shift-f11": "debugger::StepOut",
|
||||
"f11": "zed::ToggleFullScreen",
|
||||
"ctrl-alt-z": "edit_prediction::RatePredictions",
|
||||
"ctrl-alt-z": "edit_prediction::RateCompletions",
|
||||
"ctrl-alt-shift-i": "edit_prediction::ToggleMenu",
|
||||
"ctrl-alt-l": "lsp_tool::ToggleMenu"
|
||||
}
|
||||
@@ -616,8 +616,8 @@
|
||||
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
|
||||
"ctrl-t": "project_symbols::Toggle",
|
||||
"ctrl-p": "file_finder::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"ctrl-e": "file_finder::Toggle",
|
||||
"f1": "command_palette::Toggle",
|
||||
"ctrl-shift-p": "command_palette::Toggle",
|
||||
@@ -1322,18 +1322,25 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "EditPredictionContext > Editor",
|
||||
"context": "Zeta2Feedback > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::EditPredictionContextGoBack",
|
||||
"alt-right": "dev::EditPredictionContextGoForward"
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter up": "dev::Zeta2RatePredictionPositive",
|
||||
"ctrl-enter down": "dev::Zeta2RatePredictionNegative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Context > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::Zeta2ContextGoBack",
|
||||
"alt-right": "dev::Zeta2ContextGoForward"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch",
|
||||
"ctrl-shift-i": "branch_picker::FilterRemotes"
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
"cmd-m": "zed::Minimize",
|
||||
"fn-f": "zed::ToggleFullScreen",
|
||||
"ctrl-cmd-f": "zed::ToggleFullScreen",
|
||||
"ctrl-cmd-z": "edit_prediction::RatePredictions",
|
||||
"ctrl-cmd-z": "edit_prediction::RateCompletions",
|
||||
"ctrl-cmd-i": "edit_prediction::ToggleMenu",
|
||||
"ctrl-cmd-l": "lsp_tool::ToggleMenu",
|
||||
"ctrl-cmd-c": "editor::DisplayCursorNames"
|
||||
@@ -684,8 +684,8 @@
|
||||
"ctrl-alt-cmd-p": "settings_profile_selector::Toggle",
|
||||
"cmd-t": "project_symbols::Toggle",
|
||||
"cmd-p": "file_finder::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"cmd-shift-p": "command_palette::Toggle",
|
||||
"cmd-shift-m": "diagnostics::Deploy",
|
||||
"cmd-shift-e": "project_panel::ToggleFocus",
|
||||
@@ -1427,18 +1427,25 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "EditPredictionContext > Editor",
|
||||
"context": "Zeta2Feedback > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::EditPredictionContextGoBack",
|
||||
"alt-right": "dev::EditPredictionContextGoForward"
|
||||
"enter": "editor::Newline",
|
||||
"cmd-enter up": "dev::Zeta2RatePredictionPositive",
|
||||
"cmd-enter down": "dev::Zeta2RatePredictionNegative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Context > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::Zeta2ContextGoBack",
|
||||
"alt-right": "dev::Zeta2ContextGoForward"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-shift-backspace": "branch_picker::DeleteBranch",
|
||||
"cmd-shift-i": "branch_picker::FilterRemotes"
|
||||
"cmd-shift-backspace": "branch_picker::DeleteBranch"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -24,8 +24,7 @@
|
||||
"ctrl-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
|
||||
"ctrl-shift-w": "workspace::CloseWindow",
|
||||
"shift-escape": "workspace::ToggleZoom",
|
||||
"ctrl-o": "workspace::OpenFiles",
|
||||
"ctrl-k ctrl-o": "workspace::Open",
|
||||
"ctrl-o": "workspace::Open",
|
||||
"ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
|
||||
"ctrl-shift-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
|
||||
"ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],
|
||||
@@ -609,8 +608,8 @@
|
||||
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
|
||||
"ctrl-t": "project_symbols::Toggle",
|
||||
"ctrl-p": "file_finder::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"ctrl-e": "file_finder::Toggle",
|
||||
"f1": "command_palette::Toggle",
|
||||
"ctrl-shift-p": "command_palette::Toggle",
|
||||
@@ -1129,8 +1128,6 @@
|
||||
"ctrl-e": ["terminal::SendKeystroke", "ctrl-e"],
|
||||
"ctrl-o": ["terminal::SendKeystroke", "ctrl-o"],
|
||||
"ctrl-w": ["terminal::SendKeystroke", "ctrl-w"],
|
||||
"ctrl-q": ["terminal::SendKeystroke", "ctrl-q"],
|
||||
"ctrl-r": ["terminal::SendKeystroke", "ctrl-r"],
|
||||
"ctrl-backspace": ["terminal::SendKeystroke", "ctrl-w"],
|
||||
"ctrl-shift-a": "editor::SelectAll",
|
||||
"ctrl-shift-f": "buffer_search::Deploy",
|
||||
@@ -1344,18 +1341,25 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "EditPredictionContext > Editor",
|
||||
"context": "Zeta2Feedback > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::EditPredictionContextGoBack",
|
||||
"alt-right": "dev::EditPredictionContextGoForward"
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter up": "dev::Zeta2RatePredictionPositive",
|
||||
"ctrl-enter down": "dev::Zeta2RatePredictionNegative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Context > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::Zeta2ContextGoBack",
|
||||
"alt-right": "dev::Zeta2ContextGoForward"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch",
|
||||
"ctrl-shift-i": "branch_picker::FilterRemotes"
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
{{#if language_name}}
|
||||
Here's a file of {{language_name}} that the user is going to ask you to make an edit to.
|
||||
{{else}}
|
||||
Here's a file of text that the user is going to ask you to make an edit to.
|
||||
{{/if}}
|
||||
|
||||
The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
|
||||
|
||||
<document>
|
||||
{{{document_content}}}
|
||||
</document>
|
||||
|
||||
{{#if is_truncated}}
|
||||
The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
|
||||
{{/if}}
|
||||
|
||||
{{#if rewrite_section}}
|
||||
And here's the section to rewrite based on that prompt again for reference:
|
||||
|
||||
<rewrite_this>
|
||||
{{{rewrite_section}}}
|
||||
</rewrite_this>
|
||||
|
||||
{{#if diagnostic_errors}}
|
||||
Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to.
|
||||
|
||||
{{#each diagnostic_errors}}
|
||||
<diagnostic_error>
|
||||
<line_number>{{line_number}}</line_number>
|
||||
<error_message>{{error_message}}</error_message>
|
||||
<code_content>{{code_content}}</code_content>
|
||||
</diagnostic_error>
|
||||
{{/each}}
|
||||
{{/if}}
|
||||
|
||||
{{/if}}
|
||||
|
||||
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
|
||||
|
||||
Start at the indentation level in the original file in the rewritten {{content_type}}.
|
||||
|
||||
You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if
|
||||
you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so.
|
||||
It is an error if you try to make a change that cannot be made simply by editing the rewrite_section.
|
||||
@@ -12,7 +12,7 @@
|
||||
"theme": {
|
||||
"mode": "system",
|
||||
"light": "One Light",
|
||||
"dark": "One Dark"
|
||||
"dark": "One Dark",
|
||||
},
|
||||
"icon_theme": "Zed (Default)",
|
||||
// The name of a base set of key bindings to use.
|
||||
@@ -29,7 +29,7 @@
|
||||
// Features that can be globally enabled or disabled
|
||||
"features": {
|
||||
// Which edit prediction provider to use.
|
||||
"edit_prediction_provider": "zed"
|
||||
"edit_prediction_provider": "zed",
|
||||
},
|
||||
// The name of a font to use for rendering text in the editor
|
||||
// ".ZedMono" currently aliases to Lilex
|
||||
@@ -69,7 +69,7 @@
|
||||
// The OpenType features to enable for text in the UI
|
||||
"ui_font_features": {
|
||||
// Disable ligatures:
|
||||
"calt": false
|
||||
"calt": false,
|
||||
},
|
||||
// The weight of the UI font in standard CSS units from 100 to 900.
|
||||
"ui_font_weight": 400,
|
||||
@@ -87,7 +87,7 @@
|
||||
"border_size": 0.0,
|
||||
// Opacity of the inactive panes. 0 means transparent, 1 means opaque.
|
||||
// Values are clamped to the [0.0, 1.0] range.
|
||||
"inactive_opacity": 1.0
|
||||
"inactive_opacity": 1.0,
|
||||
},
|
||||
// Layout mode of the bottom dock. Defaults to "contained"
|
||||
// choices: contained, full, left_aligned, right_aligned
|
||||
@@ -103,12 +103,12 @@
|
||||
"left_padding": 0.2,
|
||||
// The relative width of the right padding of the central pane from the
|
||||
// workspace when the centered layout is used.
|
||||
"right_padding": 0.2
|
||||
"right_padding": 0.2,
|
||||
},
|
||||
// Image viewer settings
|
||||
"image_viewer": {
|
||||
// The unit for image file sizes: "binary" (KiB, MiB) or decimal (KB, MB)
|
||||
"unit": "binary"
|
||||
"unit": "binary",
|
||||
},
|
||||
// Determines the modifier to be used to add multiple cursors with the mouse. The open hover link mouse gestures will adapt such that it do not conflict with the multicursor modifier.
|
||||
//
|
||||
@@ -296,7 +296,7 @@
|
||||
// When true, enables drag and drop text selection in buffer.
|
||||
"enabled": true,
|
||||
// The delay in milliseconds that must elapse before drag and drop is allowed. Otherwise, a new text selection is created.
|
||||
"delay": 300
|
||||
"delay": 300,
|
||||
},
|
||||
// What to do when go to definition yields no results.
|
||||
//
|
||||
@@ -400,14 +400,14 @@
|
||||
// Visible characters used to render whitespace when show_whitespaces is enabled.
|
||||
"whitespace_map": {
|
||||
"space": "•",
|
||||
"tab": "→"
|
||||
"tab": "→",
|
||||
},
|
||||
// Settings related to calls in Zed
|
||||
"calls": {
|
||||
// Join calls with the microphone live by default
|
||||
"mute_on_join": false,
|
||||
// Share your project when you are the first to join a channel
|
||||
"share_on_join": false
|
||||
"share_on_join": false,
|
||||
},
|
||||
// Toolbar related settings
|
||||
"toolbar": {
|
||||
@@ -420,7 +420,7 @@
|
||||
// Whether to show agent review buttons in the editor toolbar.
|
||||
"agent_review": true,
|
||||
// Whether to show code action buttons in the editor toolbar.
|
||||
"code_actions": false
|
||||
"code_actions": false,
|
||||
},
|
||||
// Whether to allow windows to tab together based on the user’s tabbing preference (macOS only).
|
||||
"use_system_window_tabs": false,
|
||||
@@ -439,7 +439,7 @@
|
||||
// Whether to show the sign in button in the titlebar.
|
||||
"show_sign_in": true,
|
||||
// Whether to show the menus in the titlebar.
|
||||
"show_menus": false
|
||||
"show_menus": false,
|
||||
},
|
||||
"audio": {
|
||||
// Opt into the new audio system.
|
||||
@@ -472,7 +472,7 @@
|
||||
// the future we will migrate by setting this to false
|
||||
//
|
||||
// You need to rejoin a call for this setting to apply
|
||||
"experimental.legacy_audio_compatible": true
|
||||
"experimental.legacy_audio_compatible": true,
|
||||
},
|
||||
// Scrollbar related settings
|
||||
"scrollbar": {
|
||||
@@ -511,8 +511,8 @@
|
||||
// When false, forcefully disables the horizontal scrollbar. Otherwise, obey other settings.
|
||||
"horizontal": true,
|
||||
// When false, forcefully disables the vertical scrollbar. Otherwise, obey other settings.
|
||||
"vertical": true
|
||||
}
|
||||
"vertical": true,
|
||||
},
|
||||
},
|
||||
// Minimap related settings
|
||||
"minimap": {
|
||||
@@ -560,7 +560,7 @@
|
||||
// 3. "gutter" or "none" to not highlight the current line in the minimap.
|
||||
"current_line_highlight": null,
|
||||
// Maximum number of columns to display in the minimap.
|
||||
"max_width_columns": 80
|
||||
"max_width_columns": 80,
|
||||
},
|
||||
// Enable middle-click paste on Linux.
|
||||
"middle_click_paste": true,
|
||||
@@ -583,7 +583,7 @@
|
||||
// Whether to show fold buttons in the gutter.
|
||||
"folds": true,
|
||||
// Minimum number of characters to reserve space for in the gutter.
|
||||
"min_line_number_digits": 4
|
||||
"min_line_number_digits": 4,
|
||||
},
|
||||
"indent_guides": {
|
||||
// Whether to show indent guides in the editor.
|
||||
@@ -604,7 +604,7 @@
|
||||
//
|
||||
// 1. "disabled"
|
||||
// 2. "indent_aware"
|
||||
"background_coloring": "disabled"
|
||||
"background_coloring": "disabled",
|
||||
},
|
||||
// Whether the editor will scroll beyond the last line.
|
||||
"scroll_beyond_last_line": "one_page",
|
||||
@@ -623,7 +623,7 @@
|
||||
"fast_scroll_sensitivity": 4.0,
|
||||
"sticky_scroll": {
|
||||
// Whether to stick scopes to the top of the editor.
|
||||
"enabled": false
|
||||
"enabled": false,
|
||||
},
|
||||
"relative_line_numbers": "disabled",
|
||||
// If 'search_wrap' is disabled, search result do not wrap around the end of the file.
|
||||
@@ -641,7 +641,7 @@
|
||||
// Whether to interpret the search query as a regular expression.
|
||||
"regex": false,
|
||||
// Whether to center the cursor on each search match when navigating.
|
||||
"center_on_match": false
|
||||
"center_on_match": false,
|
||||
},
|
||||
// When to populate a new search's query based on the text under the cursor.
|
||||
// This setting can take the following three values:
|
||||
@@ -684,8 +684,8 @@
|
||||
"shift": false,
|
||||
"alt": false,
|
||||
"platform": false,
|
||||
"function": false
|
||||
}
|
||||
"function": false,
|
||||
},
|
||||
},
|
||||
// Whether to resize all the panels in a dock when resizing the dock.
|
||||
// Can be a combination of "left", "right" and "bottom".
|
||||
@@ -733,7 +733,7 @@
|
||||
// "always"
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null
|
||||
"show": null,
|
||||
},
|
||||
// Which files containing diagnostic errors/warnings to mark in the project panel.
|
||||
// This setting can take the following three values:
|
||||
@@ -756,7 +756,7 @@
|
||||
// "always"
|
||||
// 2. Never show indent guides:
|
||||
// "never"
|
||||
"show": "always"
|
||||
"show": "always",
|
||||
},
|
||||
// Sort order for entries in the project panel.
|
||||
// This setting can take three values:
|
||||
@@ -781,8 +781,8 @@
|
||||
// Whether to automatically open files after pasting or duplicating them.
|
||||
"on_paste": true,
|
||||
// Whether to automatically open files dropped from external sources.
|
||||
"on_drop": true
|
||||
}
|
||||
"on_drop": true,
|
||||
},
|
||||
},
|
||||
"outline_panel": {
|
||||
// Whether to show the outline panel button in the status bar
|
||||
@@ -815,7 +815,7 @@
|
||||
// "always"
|
||||
// 2. Never show indent guides:
|
||||
// "never"
|
||||
"show": "always"
|
||||
"show": "always",
|
||||
},
|
||||
// Scrollbar-related settings
|
||||
"scrollbar": {
|
||||
@@ -832,11 +832,11 @@
|
||||
// "always"
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null
|
||||
"show": null,
|
||||
},
|
||||
// Default depth to expand outline items in the current file.
|
||||
// Set to 0 to collapse all items that have children, 1 or higher to collapse items at that depth or deeper.
|
||||
"expand_outlines_with_depth": 100
|
||||
"expand_outlines_with_depth": 100,
|
||||
},
|
||||
"collaboration_panel": {
|
||||
// Whether to show the collaboration panel button in the status bar.
|
||||
@@ -844,7 +844,7 @@
|
||||
// Where to dock the collaboration panel. Can be 'left' or 'right'.
|
||||
"dock": "left",
|
||||
// Default width of the collaboration panel.
|
||||
"default_width": 240
|
||||
"default_width": 240,
|
||||
},
|
||||
"git_panel": {
|
||||
// Whether to show the git panel button in the status bar.
|
||||
@@ -876,12 +876,12 @@
|
||||
// Choices: always, auto, never, system
|
||||
// Default: inherits editor scrollbar settings
|
||||
// "show": null
|
||||
}
|
||||
},
|
||||
},
|
||||
"message_editor": {
|
||||
// Whether to automatically replace emoji shortcodes with emoji characters.
|
||||
// For example: typing `:wave:` gets replaced with `👋`.
|
||||
"auto_replace_emoji_shortcode": true
|
||||
"auto_replace_emoji_shortcode": true,
|
||||
},
|
||||
"notification_panel": {
|
||||
// Whether to show the notification panel button in the status bar.
|
||||
@@ -889,7 +889,7 @@
|
||||
// Where to dock the notification panel. Can be 'left' or 'right'.
|
||||
"dock": "right",
|
||||
// Default width of the notification panel.
|
||||
"default_width": 380
|
||||
"default_width": 380,
|
||||
},
|
||||
"agent": {
|
||||
// Whether the agent is enabled.
|
||||
@@ -911,7 +911,7 @@
|
||||
// The provider to use.
|
||||
"provider": "zed.dev",
|
||||
// The model to use.
|
||||
"model": "claude-sonnet-4"
|
||||
"model": "claude-sonnet-4",
|
||||
},
|
||||
// Additional parameters for language model requests. When making a request to a model, parameters will be taken
|
||||
// from the last entry in this list that matches the model's provider and name. In each entry, both provider
|
||||
@@ -966,8 +966,8 @@
|
||||
"grep": true,
|
||||
"terminal": true,
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
}
|
||||
"web_search": true,
|
||||
},
|
||||
},
|
||||
"ask": {
|
||||
"name": "Ask",
|
||||
@@ -984,14 +984,14 @@
|
||||
"open": true,
|
||||
"grep": true,
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
}
|
||||
"web_search": true,
|
||||
},
|
||||
},
|
||||
"minimal": {
|
||||
"name": "Minimal",
|
||||
"enable_all_context_servers": false,
|
||||
"tools": {}
|
||||
}
|
||||
"tools": {},
|
||||
},
|
||||
},
|
||||
// Where to show notifications when the agent has either completed
|
||||
// its response, or else needs confirmation before it can run a
|
||||
@@ -1020,7 +1020,7 @@
|
||||
// Minimum number of lines to display in the agent message editor.
|
||||
//
|
||||
// Default: 4
|
||||
"message_editor_min_lines": 4
|
||||
"message_editor_min_lines": 4,
|
||||
},
|
||||
// Whether the screen sharing icon is shown in the os status bar.
|
||||
"show_call_status_icon": true,
|
||||
@@ -1055,7 +1055,7 @@
|
||||
// Whether or not to show the navigation history buttons.
|
||||
"show_nav_history_buttons": true,
|
||||
// Whether or not to show the tab bar buttons.
|
||||
"show_tab_bar_buttons": true
|
||||
"show_tab_bar_buttons": true,
|
||||
},
|
||||
// Settings related to the editor's tabs
|
||||
"tabs": {
|
||||
@@ -1094,7 +1094,7 @@
|
||||
// "errors"
|
||||
// 3. Mark files with errors and warnings:
|
||||
// "all"
|
||||
"show_diagnostics": "off"
|
||||
"show_diagnostics": "off",
|
||||
},
|
||||
// Settings related to preview tabs.
|
||||
"preview_tabs": {
|
||||
@@ -1115,7 +1115,7 @@
|
||||
"enable_preview_file_from_code_navigation": true,
|
||||
// Whether to keep tabs in preview mode when code navigation is used to navigate away from them.
|
||||
// If `enable_preview_file_from_code_navigation` or `enable_preview_multibuffer_from_code_navigation` is also true, the new tab may replace the existing one.
|
||||
"enable_keep_preview_on_code_navigation": false
|
||||
"enable_keep_preview_on_code_navigation": false,
|
||||
},
|
||||
// Settings related to the file finder.
|
||||
"file_finder": {
|
||||
@@ -1159,7 +1159,7 @@
|
||||
// * "all": Use all gitignored files
|
||||
// * "indexed": Use only the files Zed had indexed
|
||||
// * "smart": Be smart and search for ignored when called from a gitignored worktree
|
||||
"include_ignored": "smart"
|
||||
"include_ignored": "smart",
|
||||
},
|
||||
// Whether or not to remove any trailing whitespace from lines of a buffer
|
||||
// before saving it.
|
||||
@@ -1230,7 +1230,7 @@
|
||||
// Send debug info like crash reports.
|
||||
"diagnostics": true,
|
||||
// Send anonymized usage data like what languages you're using Zed with.
|
||||
"metrics": true
|
||||
"metrics": true,
|
||||
},
|
||||
// Whether to disable all AI features in Zed.
|
||||
//
|
||||
@@ -1264,7 +1264,7 @@
|
||||
"enabled": true,
|
||||
// Minimum time to wait before pulling diagnostics from the language server(s).
|
||||
// 0 turns the debounce off.
|
||||
"debounce_ms": 50
|
||||
"debounce_ms": 50,
|
||||
},
|
||||
// Settings for inline diagnostics
|
||||
"inline": {
|
||||
@@ -1282,8 +1282,8 @@
|
||||
"min_column": 0,
|
||||
// The minimum severity of the diagnostics to show inline.
|
||||
// Inherits editor's diagnostics' max severity settings when `null`.
|
||||
"max_severity": null
|
||||
}
|
||||
"max_severity": null,
|
||||
},
|
||||
},
|
||||
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file
|
||||
// scans, file searches, and not be displayed in the project file tree. Takes precedence over `file_scan_inclusions`.
|
||||
@@ -1297,7 +1297,7 @@
|
||||
"**/.DS_Store",
|
||||
"**/Thumbs.db",
|
||||
"**/.classpath",
|
||||
"**/.settings"
|
||||
"**/.settings",
|
||||
],
|
||||
// Files or globs of files that will be included by Zed, even when ignored by git. This is useful
|
||||
// for files that are not tracked by git, but are still important to your project. Note that globs
|
||||
@@ -1332,14 +1332,14 @@
|
||||
// Whether or not to display the git commit summary on the same line.
|
||||
"show_commit_summary": false,
|
||||
// The minimum column number to show the inline blame information at
|
||||
"min_column": 0
|
||||
"min_column": 0,
|
||||
},
|
||||
"blame": {
|
||||
"show_avatar": true
|
||||
"show_avatar": true,
|
||||
},
|
||||
// Control which information is shown in the branch picker.
|
||||
"branch_picker": {
|
||||
"show_author_name": true
|
||||
"show_author_name": true,
|
||||
},
|
||||
// How git hunks are displayed visually in the editor.
|
||||
// This setting can take two values:
|
||||
@@ -1351,7 +1351,7 @@
|
||||
"hunk_style": "staged_hollow",
|
||||
// Should the name or path be displayed first in the git view.
|
||||
// "path_style": "file_name_first" or "file_path_first"
|
||||
"path_style": "file_name_first"
|
||||
"path_style": "file_name_first",
|
||||
},
|
||||
// The list of custom Git hosting providers.
|
||||
"git_hosting_providers": [
|
||||
@@ -1385,7 +1385,7 @@
|
||||
"**/secrets.yml",
|
||||
"**/.zed/settings.json", // zed project settings
|
||||
"/**/zed/settings.json", // zed user settings
|
||||
"/**/zed/keymap.json"
|
||||
"/**/zed/keymap.json",
|
||||
],
|
||||
// When to show edit predictions previews in buffer.
|
||||
// This setting takes two possible values:
|
||||
@@ -1403,15 +1403,15 @@
|
||||
"copilot": {
|
||||
"enterprise_uri": null,
|
||||
"proxy": null,
|
||||
"proxy_no_verify": null
|
||||
"proxy_no_verify": null,
|
||||
},
|
||||
"codestral": {
|
||||
"model": null,
|
||||
"max_tokens": null
|
||||
"max_tokens": null,
|
||||
},
|
||||
// Whether edit predictions are enabled when editing text threads in the agent panel.
|
||||
// This setting has no effect if globally disabled.
|
||||
"enabled_in_text_threads": true
|
||||
"enabled_in_text_threads": true,
|
||||
},
|
||||
// Settings specific to journaling
|
||||
"journal": {
|
||||
@@ -1421,7 +1421,7 @@
|
||||
// May take 2 values:
|
||||
// 1. hour12
|
||||
// 2. hour24
|
||||
"hour_format": "hour12"
|
||||
"hour_format": "hour12",
|
||||
},
|
||||
// Status bar-related settings.
|
||||
"status_bar": {
|
||||
@@ -1432,7 +1432,7 @@
|
||||
// Whether to show the cursor position button in the status bar.
|
||||
"cursor_position_button": true,
|
||||
// Whether to show active line endings button in the status bar.
|
||||
"line_endings_button": false
|
||||
"line_endings_button": false,
|
||||
},
|
||||
// Settings specific to the terminal
|
||||
"terminal": {
|
||||
@@ -1553,8 +1553,8 @@
|
||||
// Preferred Conda manager to use when activating Conda environments.
|
||||
// Values: "auto", "conda", "mamba", "micromamba"
|
||||
// Default: "auto"
|
||||
"conda_manager": "auto"
|
||||
}
|
||||
"conda_manager": "auto",
|
||||
},
|
||||
},
|
||||
"toolbar": {
|
||||
// Whether to display the terminal title in its toolbar's breadcrumbs.
|
||||
@@ -1562,7 +1562,7 @@
|
||||
//
|
||||
// The shell running in the terminal needs to be configured to emit the title.
|
||||
// Example: `echo -e "\e]2;New Title\007";`
|
||||
"breadcrumbs": false
|
||||
"breadcrumbs": false,
|
||||
},
|
||||
// Scrollbar-related settings
|
||||
"scrollbar": {
|
||||
@@ -1579,7 +1579,7 @@
|
||||
// "always"
|
||||
// 5. Never show the scrollbar:
|
||||
// "never"
|
||||
"show": null
|
||||
"show": null,
|
||||
},
|
||||
// Set the terminal's font size. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font size.
|
||||
@@ -1660,12 +1660,12 @@
|
||||
"# which may be followed by trailing punctuation",
|
||||
"[.,:)}\\]>]*",
|
||||
"# and always includes trailing whitespace or end of line",
|
||||
"([ ]+|$)"
|
||||
]
|
||||
"([ ]+|$)",
|
||||
],
|
||||
],
|
||||
// Timeout for hover and Cmd-click path hyperlink discovery in milliseconds. Specifying a
|
||||
// timeout of `0` will disable path hyperlinking in terminal.
|
||||
"path_hyperlink_timeout_ms": 1
|
||||
"path_hyperlink_timeout_ms": 1,
|
||||
},
|
||||
"code_actions_on_format": {},
|
||||
// Settings related to running tasks.
|
||||
@@ -1681,7 +1681,7 @@
|
||||
// * Zed task from history (e.g. one-off task was spawned before)
|
||||
//
|
||||
// Default: true
|
||||
"prefer_lsp": true
|
||||
"prefer_lsp": true,
|
||||
},
|
||||
// An object whose keys are language names, and whose values
|
||||
// are arrays of filenames or extensions of files that should
|
||||
@@ -1698,7 +1698,7 @@
|
||||
"file_types": {
|
||||
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json", "tsconfig*.json"],
|
||||
"Markdown": [".rules", ".cursorrules", ".windsurfrules", ".clinerules"],
|
||||
"Shell Script": [".env.*"]
|
||||
"Shell Script": [".env.*"],
|
||||
},
|
||||
// Settings for which version of Node.js and NPM to use when installing
|
||||
// language servers and Copilot.
|
||||
@@ -1714,14 +1714,14 @@
|
||||
// `path`, but not `npm_path`, Zed will assume that `npm` is located at
|
||||
// `${path}/../npm`.
|
||||
"path": null,
|
||||
"npm_path": null
|
||||
"npm_path": null,
|
||||
},
|
||||
// The extensions that Zed should automatically install on startup.
|
||||
//
|
||||
// If you don't want any of these extensions, add this field to your settings
|
||||
// and change the value to `false`.
|
||||
"auto_install_extensions": {
|
||||
"html": true
|
||||
"html": true,
|
||||
},
|
||||
// The capabilities granted to extensions.
|
||||
//
|
||||
@@ -1729,7 +1729,7 @@
|
||||
"granted_extension_capabilities": [
|
||||
{ "kind": "process:exec", "command": "*", "args": ["**"] },
|
||||
{ "kind": "download_file", "host": "*", "path": ["**"] },
|
||||
{ "kind": "npm:install", "package": "*" }
|
||||
{ "kind": "npm:install", "package": "*" },
|
||||
],
|
||||
// Controls how completions are processed for this language.
|
||||
"completions": {
|
||||
@@ -1780,7 +1780,7 @@
|
||||
// 4. "replace_suffix"
|
||||
// Behaves like `"replace"` if the text after the cursor is a suffix of the completion, and like
|
||||
// `"insert"` otherwise.
|
||||
"lsp_insert_mode": "replace_suffix"
|
||||
"lsp_insert_mode": "replace_suffix",
|
||||
},
|
||||
// Different settings for specific languages.
|
||||
"languages": {
|
||||
@@ -1788,113 +1788,113 @@
|
||||
"language_servers": ["astro-language-server", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-astro"]
|
||||
}
|
||||
"plugins": ["prettier-plugin-astro"],
|
||||
},
|
||||
},
|
||||
"Blade": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"C": {
|
||||
"format_on_save": "off",
|
||||
"use_on_type_format": false,
|
||||
"prettier": {
|
||||
"allowed": false
|
||||
}
|
||||
"allowed": false,
|
||||
},
|
||||
},
|
||||
"C++": {
|
||||
"format_on_save": "off",
|
||||
"use_on_type_format": false,
|
||||
"prettier": {
|
||||
"allowed": false
|
||||
}
|
||||
"allowed": false,
|
||||
},
|
||||
},
|
||||
"CSS": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"Dart": {
|
||||
"tab_size": 2
|
||||
"tab_size": 2,
|
||||
},
|
||||
"Diff": {
|
||||
"show_edit_predictions": false,
|
||||
"remove_trailing_whitespace_on_save": false,
|
||||
"ensure_final_newline_on_save": false
|
||||
"ensure_final_newline_on_save": false,
|
||||
},
|
||||
"Elixir": {
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
|
||||
},
|
||||
"Elm": {
|
||||
"tab_size": 4
|
||||
"tab_size": 4,
|
||||
},
|
||||
"Erlang": {
|
||||
"language_servers": ["erlang-ls", "!elp", "..."]
|
||||
"language_servers": ["erlang-ls", "!elp", "..."],
|
||||
},
|
||||
"Git Commit": {
|
||||
"allow_rewrap": "anywhere",
|
||||
"soft_wrap": "editor_width",
|
||||
"preferred_line_length": 72
|
||||
"preferred_line_length": 72,
|
||||
},
|
||||
"Go": {
|
||||
"hard_tabs": true,
|
||||
"code_actions_on_format": {
|
||||
"source.organizeImports": true
|
||||
"source.organizeImports": true,
|
||||
},
|
||||
"debuggers": ["Delve"]
|
||||
"debuggers": ["Delve"],
|
||||
},
|
||||
"GraphQL": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"HEEX": {
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."]
|
||||
"language_servers": ["elixir-ls", "!expert", "!next-ls", "!lexical", "..."],
|
||||
},
|
||||
"HTML": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"HTML+ERB": {
|
||||
"language_servers": ["herb", "!ruby-lsp", "..."]
|
||||
"language_servers": ["herb", "!ruby-lsp", "..."],
|
||||
},
|
||||
"Java": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-java"]
|
||||
}
|
||||
"plugins": ["prettier-plugin-java"],
|
||||
},
|
||||
},
|
||||
"JavaScript": {
|
||||
"language_servers": ["!typescript-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"JSON": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"JSONC": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"JS+ERB": {
|
||||
"language_servers": ["!ruby-lsp", "..."]
|
||||
"language_servers": ["!ruby-lsp", "..."],
|
||||
},
|
||||
"Kotlin": {
|
||||
"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."]
|
||||
"language_servers": ["!kotlin-language-server", "kotlin-lsp", "..."],
|
||||
},
|
||||
"LaTeX": {
|
||||
"formatter": "language_server",
|
||||
"language_servers": ["texlab", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-latex"]
|
||||
}
|
||||
"plugins": ["prettier-plugin-latex"],
|
||||
},
|
||||
},
|
||||
"Markdown": {
|
||||
"format_on_save": "off",
|
||||
@@ -1903,135 +1903,132 @@
|
||||
"allow_rewrap": "anywhere",
|
||||
"soft_wrap": "editor_width",
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"PHP": {
|
||||
"language_servers": ["phpactor", "!intelephense", "!phptools", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["@prettier/plugin-php"],
|
||||
"parser": "php"
|
||||
}
|
||||
"parser": "php",
|
||||
},
|
||||
},
|
||||
"Plain Text": {
|
||||
"allow_rewrap": "anywhere",
|
||||
"soft_wrap": "editor_width"
|
||||
"soft_wrap": "editor_width",
|
||||
},
|
||||
"Python": {
|
||||
"code_actions_on_format": {
|
||||
"source.organizeImports.ruff": true
|
||||
"source.organizeImports.ruff": true,
|
||||
},
|
||||
"formatter": {
|
||||
"language_server": {
|
||||
"name": "ruff"
|
||||
}
|
||||
"name": "ruff",
|
||||
},
|
||||
},
|
||||
"debuggers": ["Debugpy"],
|
||||
"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."]
|
||||
"language_servers": ["basedpyright", "ruff", "!ty", "!pyrefly", "!pyright", "!pylsp", "..."],
|
||||
},
|
||||
"Ruby": {
|
||||
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."]
|
||||
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "!sorbet", "!steep", "..."],
|
||||
},
|
||||
"Rust": {
|
||||
"debuggers": ["CodeLLDB"]
|
||||
"debuggers": ["CodeLLDB"],
|
||||
},
|
||||
"SCSS": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"Starlark": {
|
||||
"language_servers": ["starpls", "!buck2-lsp", "..."]
|
||||
"language_servers": ["starpls", "!buck2-lsp", "..."],
|
||||
},
|
||||
"Svelte": {
|
||||
"language_servers": ["svelte-language-server", "..."],
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-svelte"]
|
||||
}
|
||||
"plugins": ["prettier-plugin-svelte"],
|
||||
},
|
||||
},
|
||||
"TSX": {
|
||||
"language_servers": ["!typescript-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"Twig": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"TypeScript": {
|
||||
"language_servers": ["!typescript-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"SystemVerilog": {
|
||||
"format_on_save": "off",
|
||||
"language_servers": ["!slang", "..."],
|
||||
"use_on_type_format": false
|
||||
"use_on_type_format": false,
|
||||
},
|
||||
"Vue.js": {
|
||||
"language_servers": ["vue-language-server", "vtsls", "..."],
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"XML": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["@prettier/plugin-xml"]
|
||||
}
|
||||
"plugins": ["@prettier/plugin-xml"],
|
||||
},
|
||||
},
|
||||
"YAML": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"allowed": true,
|
||||
},
|
||||
},
|
||||
"YAML+ERB": {
|
||||
"language_servers": ["!ruby-lsp", "..."]
|
||||
"language_servers": ["!ruby-lsp", "..."],
|
||||
},
|
||||
"Zig": {
|
||||
"language_servers": ["zls", "..."]
|
||||
}
|
||||
"language_servers": ["zls", "..."],
|
||||
},
|
||||
},
|
||||
// Different settings for specific language models.
|
||||
"language_models": {
|
||||
"anthropic": {
|
||||
"api_url": "https://api.anthropic.com"
|
||||
},
|
||||
"bedrock": {},
|
||||
"google": {
|
||||
"api_url": "https://generativelanguage.googleapis.com"
|
||||
"api_url": "https://generativelanguage.googleapis.com",
|
||||
},
|
||||
"ollama": {
|
||||
"api_url": "http://localhost:11434"
|
||||
"api_url": "http://localhost:11434",
|
||||
},
|
||||
"openai": {
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
"api_url": "https://api.openai.com/v1",
|
||||
},
|
||||
"openai_compatible": {},
|
||||
"open_router": {
|
||||
"api_url": "https://openrouter.ai/api/v1"
|
||||
"api_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
"lmstudio": {
|
||||
"api_url": "http://localhost:1234/api/v0"
|
||||
"api_url": "http://localhost:1234/api/v0",
|
||||
},
|
||||
"deepseek": {
|
||||
"api_url": "https://api.deepseek.com/v1"
|
||||
"api_url": "https://api.deepseek.com/v1",
|
||||
},
|
||||
"mistral": {
|
||||
"api_url": "https://api.mistral.ai/v1"
|
||||
"api_url": "https://api.mistral.ai/v1",
|
||||
},
|
||||
"vercel": {
|
||||
"api_url": "https://api.v0.dev/v1"
|
||||
"api_url": "https://api.v0.dev/v1",
|
||||
},
|
||||
"x_ai": {
|
||||
"api_url": "https://api.x.ai/v1"
|
||||
"api_url": "https://api.x.ai/v1",
|
||||
},
|
||||
"zed.dev": {}
|
||||
"zed.dev": {},
|
||||
},
|
||||
"session": {
|
||||
// Whether or not to restore unsaved buffers on restart.
|
||||
@@ -2040,7 +2037,7 @@
|
||||
// dirty files when closing the application.
|
||||
//
|
||||
// Default: true
|
||||
"restore_unsaved_buffers": true
|
||||
"restore_unsaved_buffers": true,
|
||||
},
|
||||
// Zed's Prettier integration settings.
|
||||
// Allows to enable/disable formatting with Prettier
|
||||
@@ -2058,11 +2055,11 @@
|
||||
// "singleQuote": true
|
||||
// Forces Prettier integration to use a specific parser name when formatting files with the language
|
||||
// when set to a non-empty string.
|
||||
"parser": ""
|
||||
"parser": "",
|
||||
},
|
||||
// Settings for auto-closing of JSX tags.
|
||||
"jsx_tag_auto_close": {
|
||||
"enabled": true
|
||||
"enabled": true,
|
||||
},
|
||||
// LSP Specific settings.
|
||||
"lsp": {
|
||||
@@ -2083,19 +2080,19 @@
|
||||
// Specify the DAP name as a key here.
|
||||
"CodeLLDB": {
|
||||
"env": {
|
||||
"RUST_LOG": "info"
|
||||
}
|
||||
}
|
||||
"RUST_LOG": "info",
|
||||
},
|
||||
},
|
||||
},
|
||||
// Common language server settings.
|
||||
"global_lsp_settings": {
|
||||
// Whether to show the LSP servers button in the status bar.
|
||||
"button": true
|
||||
"button": true,
|
||||
},
|
||||
// Jupyter settings
|
||||
"jupyter": {
|
||||
"enabled": true,
|
||||
"kernel_selections": {}
|
||||
"kernel_selections": {},
|
||||
// Specify the language name as the key and the kernel name as the value.
|
||||
// "kernel_selections": {
|
||||
// "python": "conda-base"
|
||||
@@ -2109,7 +2106,7 @@
|
||||
"max_columns": 128,
|
||||
// Maximum number of lines to keep in REPL's scrollback buffer.
|
||||
// Clamped with [4, 256] range.
|
||||
"max_lines": 32
|
||||
"max_lines": 32,
|
||||
},
|
||||
// Vim settings
|
||||
"vim": {
|
||||
@@ -2123,7 +2120,7 @@
|
||||
// Specify the mode as the key and the shape as the value.
|
||||
// The mode can be one of the following: "normal", "replace", "insert", "visual".
|
||||
// The shape can be one of the following: "block", "bar", "underline", "hollow".
|
||||
"cursor_shape": {}
|
||||
"cursor_shape": {},
|
||||
},
|
||||
// The server to connect to. If the environment variable
|
||||
// ZED_SERVER_URL is set, it will override this setting.
|
||||
@@ -2156,9 +2153,9 @@
|
||||
"windows": {
|
||||
"languages": {
|
||||
"PHP": {
|
||||
"language_servers": ["intelephense", "!phpactor", "!phptools", "..."]
|
||||
}
|
||||
}
|
||||
"language_servers": ["intelephense", "!phpactor", "!phptools", "..."],
|
||||
},
|
||||
},
|
||||
},
|
||||
// Whether to show full labels in line indicator or short ones
|
||||
//
|
||||
@@ -2217,7 +2214,7 @@
|
||||
"dock": "bottom",
|
||||
"log_dap_communications": true,
|
||||
"format_dap_log_messages": true,
|
||||
"button": true
|
||||
"button": true,
|
||||
},
|
||||
// Configures any number of settings profiles that are temporarily applied on
|
||||
// top of your existing user settings when selected from
|
||||
@@ -2244,5 +2241,5 @@
|
||||
// Useful for filtering out noisy logs or enabling more verbose logging.
|
||||
//
|
||||
// Example: {"log": {"client": "warn"}}
|
||||
"log": {}
|
||||
"log": {},
|
||||
}
|
||||
|
||||
@@ -2929,7 +2929,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
|
||||
assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -204,12 +204,21 @@ pub trait AgentModelSelector: 'static {
|
||||
}
|
||||
}
|
||||
|
||||
/// Icon for a model in the model selector.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AgentModelIcon {
|
||||
/// A built-in icon from Zed's icon set.
|
||||
Named(IconName),
|
||||
/// Path to a custom SVG icon file.
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentModelInfo {
|
||||
pub id: acp::ModelId,
|
||||
pub name: SharedString,
|
||||
pub description: Option<SharedString>,
|
||||
pub icon: Option<IconName>,
|
||||
pub icon: Option<AgentModelIcon>,
|
||||
}
|
||||
|
||||
impl From<acp::ModelInfo> for AgentModelInfo {
|
||||
|
||||
@@ -75,9 +75,15 @@ impl Terminal {
|
||||
|
||||
let exit_status = exit_status.map(portable_pty::ExitStatus::from);
|
||||
|
||||
acp::TerminalExitStatus::new()
|
||||
.exit_code(exit_status.as_ref().map(|e| e.exit_code()))
|
||||
.signal(exit_status.and_then(|e| e.signal().map(ToOwned::to_owned)))
|
||||
let mut status = acp::TerminalExitStatus::new();
|
||||
|
||||
if let Some(exit_status) = exit_status.as_ref() {
|
||||
status = status.exit_code(exit_status.exit_code());
|
||||
if let Some(signal) = exit_status.signal() {
|
||||
status = status.signal(signal);
|
||||
}
|
||||
}
|
||||
status
|
||||
})
|
||||
.shared(),
|
||||
}
|
||||
@@ -99,17 +105,19 @@ impl Terminal {
|
||||
|
||||
pub fn current_output(&self, cx: &App) -> acp::TerminalOutputResponse {
|
||||
if let Some(output) = self.output.as_ref() {
|
||||
let exit_status = output.exit_status.map(portable_pty::ExitStatus::from);
|
||||
let mut exit_status = acp::TerminalExitStatus::new();
|
||||
if let Some(status) = output.exit_status.map(portable_pty::ExitStatus::from) {
|
||||
exit_status = exit_status.exit_code(status.exit_code());
|
||||
if let Some(signal) = status.signal() {
|
||||
exit_status = exit_status.signal(signal);
|
||||
}
|
||||
}
|
||||
|
||||
acp::TerminalOutputResponse::new(
|
||||
output.content.clone(),
|
||||
output.original_content_len > output.content.len(),
|
||||
)
|
||||
.exit_status(
|
||||
acp::TerminalExitStatus::new()
|
||||
.exit_code(exit_status.as_ref().map(|e| e.exit_code()))
|
||||
.signal(exit_status.and_then(|e| e.signal().map(ToOwned::to_owned))),
|
||||
)
|
||||
.exit_status(exit_status)
|
||||
} else {
|
||||
let (current_content, original_len) = self.truncated_output(cx);
|
||||
let truncated = current_content.len() < original_len;
|
||||
|
||||
@@ -18,7 +18,7 @@ pub use templates::*;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
||||
use acp_thread::{AcpThread, AgentModelSelector};
|
||||
use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -161,11 +161,16 @@ impl LanguageModels {
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
) -> acp_thread::AgentModelInfo {
|
||||
let icon = if let Some(path) = provider.icon_path() {
|
||||
Some(AgentModelIcon::Path(path))
|
||||
} else {
|
||||
Some(AgentModelIcon::Named(provider.icon()))
|
||||
};
|
||||
acp_thread::AgentModelInfo {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
description: None,
|
||||
icon: Some(provider.icon()),
|
||||
icon,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1356,7 +1361,7 @@ mod internal_tests {
|
||||
id: acp::ModelId::new("fake/fake"),
|
||||
name: "Fake".into(),
|
||||
description: None,
|
||||
icon: Some(ui::IconName::ZedAssistant),
|
||||
icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
- We're starting from a completely blank project
|
||||
- Like Aider/Claude Code you take the user's initial prompt and then call the LLM and perform tool calls in a loop until the ultimate goal is achieved.
|
||||
- Unlike Aider or Claude code, it's not intended to be interactive. Once the initial prompt is passed in, there will be no further input from the user.
|
||||
- The system you will build must reach the stated goal just by performing tool calls and calling the LLM
|
||||
- The system you will build must reach the stated goal just by performing too calls and calling the LLM
|
||||
- I want you to build this in python. Use the anthropic python sdk and the model context protocol sdk. Use a virtual env and pip to install dependencies
|
||||
- Follow the anthropic guidance on tool calls: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview
|
||||
- Use this Anthropic model: `claude-3-7-sonnet-20250219`
|
||||
- Use this Anthropic API Key: `sk-ant-api03-qweeryiofdjsncmxquywefidopsugus`
|
||||
- One of the most important pieces to this is having good tool calls. We will be using the tools provided by the Claude MCP server. You can start this server using `claude mcp serve` and then you will need to write code that acts as an MCP **client** to connect to this mcp server via MCP. Likely you want to start this using a subprocess. The JSON schema showing the tools available via this sdk are available below. Via this MCP server you have access to all the tools that zode needs: Bash, GlobTool, GrepTool, LS, View, Edit, Replace, WebFetchTool
|
||||
- One of the most important pieces to this is having good too calls. We will be using the tools provided by the Claude MCP server. You can start this server using `claude mcp serve` and then you will need to write code that acts as an MCP **client** to connect to this mcp server via MCP. Likely you want to start this using a subprocess. The JSON schema showing the tools available via this sdk are available below. Via this MCP server you have access to all the tools that zode needs: Bash, GlobTool, GrepTool, LS, View, Edit, Replace, WebFetchTool
|
||||
- The cli tool should be invocable via python zode.py file.md where file.md is any possible file that contains the users prompt. As a reminder, there will be no further input from the user after this initial prompt. Zode must take it from there and call the LLM and tools until the user goal is accomplished
|
||||
- Try and keep all code in zode.py and make heavy use of the asks I mentioned
|
||||
- Once you’ve implemented this, you must run python zode.py eval/instructions.md to see how well our new agent tool does!
|
||||
|
||||
@@ -2094,7 +2094,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
"1",
|
||||
acp::ToolCallUpdateFields::new()
|
||||
.status(acp::ToolCallStatus::Completed)
|
||||
.raw_output("Finished thinking.")
|
||||
.raw_output("Finished thinking.".into())
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -766,22 +766,20 @@ impl Thread {
|
||||
.log_err();
|
||||
}
|
||||
|
||||
stream.update_tool_call_fields(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields::new()
|
||||
.status(
|
||||
tool_result
|
||||
.as_ref()
|
||||
.map_or(acp::ToolCallStatus::Failed, |result| {
|
||||
if result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}
|
||||
}),
|
||||
)
|
||||
.raw_output(output),
|
||||
);
|
||||
let mut fields = acp::ToolCallUpdateFields::new().status(tool_result.as_ref().map_or(
|
||||
acp::ToolCallStatus::Failed,
|
||||
|result| {
|
||||
if result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}
|
||||
},
|
||||
));
|
||||
if let Some(output) = output {
|
||||
fields = fields.raw_output(output);
|
||||
}
|
||||
stream.update_tool_call_fields(&tool_use.id, fields);
|
||||
}
|
||||
|
||||
pub fn from_db(
|
||||
@@ -1261,16 +1259,15 @@ impl Thread {
|
||||
while let Some(tool_result) = tool_results.next().await {
|
||||
log::debug!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields::new()
|
||||
.status(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
})
|
||||
.raw_output(tool_result.output.clone()),
|
||||
);
|
||||
let mut fields = acp::ToolCallUpdateFields::new().status(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
});
|
||||
if let Some(output) = &tool_result.output {
|
||||
fields = fields.raw_output(output.clone());
|
||||
}
|
||||
event_stream.update_tool_call_fields(&tool_result.tool_use_id, fields);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
@@ -1548,7 +1545,7 @@ impl Thread {
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields::new()
|
||||
.title(title.as_str())
|
||||
.title(title)
|
||||
.kind(kind)
|
||||
.raw_input(tool_use.input.clone()),
|
||||
);
|
||||
@@ -2464,7 +2461,7 @@ impl ToolCallEventStream {
|
||||
ToolCallAuthorization {
|
||||
tool_call: acp::ToolCallUpdate::new(
|
||||
self.tool_use_id.to_string(),
|
||||
acp::ToolCallUpdateFields::new().title(title.into()),
|
||||
acp::ToolCallUpdateFields::new().title(title),
|
||||
),
|
||||
options: vec![
|
||||
acp::PermissionOption::new(
|
||||
|
||||
@@ -4,7 +4,6 @@ mod create_directory_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_file_tool;
|
||||
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
mod grep_tool;
|
||||
@@ -13,7 +12,6 @@ mod move_path_tool;
|
||||
mod now_tool;
|
||||
mod open_tool;
|
||||
mod read_file_tool;
|
||||
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod web_search_tool;
|
||||
@@ -27,7 +25,6 @@ pub use create_directory_tool::*;
|
||||
pub use delete_path_tool::*;
|
||||
pub use diagnostics_tool::*;
|
||||
pub use edit_file_tool::*;
|
||||
|
||||
pub use fetch_tool::*;
|
||||
pub use find_path_tool::*;
|
||||
pub use grep_tool::*;
|
||||
@@ -36,7 +33,6 @@ pub use move_path_tool::*;
|
||||
pub use now_tool::*;
|
||||
pub use open_tool::*;
|
||||
pub use read_file_tool::*;
|
||||
|
||||
pub use terminal_tool::*;
|
||||
pub use thinking_tool::*;
|
||||
pub use web_search_tool::*;
|
||||
|
||||
@@ -384,7 +384,11 @@ impl AgentTool for EditFileTool {
|
||||
range.start.to_point(&buffer.snapshot()).row
|
||||
}).ok();
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
|
||||
let mut location = ToolCallLocation::new(abs_path);
|
||||
if let Some(line) = line {
|
||||
location = location.line(line);
|
||||
}
|
||||
event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location]));
|
||||
}
|
||||
emitted_location = true;
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ impl AgentTool for FindPathTool {
|
||||
)),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
|
||||
|
||||
@@ -322,6 +322,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
@@ -563,7 +564,7 @@ mod tests {
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
|
||||
project.update(cx, |project, _cx| {
|
||||
project.languages().add(language::rust_lang())
|
||||
project.languages().add(rust_lang().into())
|
||||
});
|
||||
|
||||
project
|
||||
@@ -792,6 +793,22 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_outline_query(include_str!("../../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_grep_security_boundaries(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
@@ -152,11 +152,12 @@ impl AgentTool for ReadFileTool {
|
||||
}
|
||||
|
||||
let file_path = input.path.clone();
|
||||
let mut location = acp::ToolCallLocation::new(&abs_path);
|
||||
if let Some(line) = input.start_line {
|
||||
location = location.line(line.saturating_sub(1));
|
||||
}
|
||||
|
||||
event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![
|
||||
acp::ToolCallLocation::new(&abs_path)
|
||||
.line(input.start_line.map(|line| line.saturating_sub(1))),
|
||||
]));
|
||||
event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location]));
|
||||
|
||||
if image_store::is_image_file(&self.project, &project_path, cx) {
|
||||
return cx.spawn(async move |cx| {
|
||||
@@ -301,6 +302,7 @@ mod test {
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates, Thread};
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project};
|
||||
use prompt_store::ProjectContext;
|
||||
@@ -404,7 +406,7 @@ mod test {
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
language_registry.add(language::rust_lang());
|
||||
language_registry.add(Arc::new(rust_lang()));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
@@ -594,6 +596,49 @@ mod test {
|
||||
});
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_outline_query(
|
||||
r#"
|
||||
(line_comment) @annotation
|
||||
|
||||
(struct_item
|
||||
"struct" @context
|
||||
name: (_) @name) @item
|
||||
(enum_item
|
||||
"enum" @context
|
||||
name: (_) @name) @item
|
||||
(enum_variant
|
||||
name: (_) @name) @item
|
||||
(field_declaration
|
||||
name: (_) @name) @item
|
||||
(impl_item
|
||||
"impl" @context
|
||||
trait: (_)? @name
|
||||
"for"? @context
|
||||
type: (_) @name
|
||||
body: (_ "{" (_)* "}")) @item
|
||||
(function_item
|
||||
"fn" @context
|
||||
name: (_) @name) @item
|
||||
(mod_item
|
||||
"mod" @context
|
||||
name: (_) @name) @item
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_file_security(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
@@ -121,7 +121,7 @@ fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream)
|
||||
),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -173,6 +173,10 @@ impl AcpConnection {
|
||||
});
|
||||
})?;
|
||||
|
||||
let mut client_info = acp::Implementation::new("zed", version);
|
||||
if let Some(release_channel) = release_channel {
|
||||
client_info = client_info.title(release_channel);
|
||||
}
|
||||
let response = connection
|
||||
.initialize(
|
||||
acp::InitializeRequest::new(acp::ProtocolVersion::V1)
|
||||
@@ -188,10 +192,7 @@ impl AcpConnection {
|
||||
("terminal-auth".into(), true.into()),
|
||||
])),
|
||||
)
|
||||
.client_info(
|
||||
acp::Implementation::new("zed", version)
|
||||
.title(release_channel.map(ToOwned::to_owned)),
|
||||
),
|
||||
.client_info(client_info),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -301,10 +302,10 @@ impl AgentConnection for AcpConnection {
|
||||
.new_session(acp::NewSessionRequest::new(cwd).mcp_servers(mcp_servers))
|
||||
.await
|
||||
.map_err(|err| {
|
||||
if err.code == acp::ErrorCode::AuthRequired {
|
||||
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
|
||||
let mut error = AuthRequired::new();
|
||||
|
||||
if err.message != acp::ErrorCode::AuthRequired.to_string() {
|
||||
if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
|
||||
error = error.with_description(err.message);
|
||||
}
|
||||
|
||||
@@ -466,11 +467,11 @@ impl AgentConnection for AcpConnection {
|
||||
match result {
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
if err.code == acp::ErrorCode::AuthRequired {
|
||||
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
|
||||
return Err(anyhow!(acp::Error::auth_required()));
|
||||
}
|
||||
|
||||
if err.code != ErrorCode::InternalError {
|
||||
if err.code != ErrorCode::INTERNAL_ERROR.code {
|
||||
anyhow::bail!(err)
|
||||
}
|
||||
|
||||
@@ -837,18 +838,13 @@ impl acp::Client for ClientDelegate {
|
||||
if let Some(term_exit) = meta.get("terminal_exit") {
|
||||
if let Some(id_str) = term_exit.get("terminal_id").and_then(|v| v.as_str()) {
|
||||
let terminal_id = acp::TerminalId::new(id_str);
|
||||
let status = acp::TerminalExitStatus::new()
|
||||
.exit_code(
|
||||
term_exit
|
||||
.get("exit_code")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|i| i as u32),
|
||||
)
|
||||
.signal(
|
||||
term_exit
|
||||
.get("signal")
|
||||
.and_then(|v| v.as_str().map(|s| s.to_string())),
|
||||
);
|
||||
let mut status = acp::TerminalExitStatus::new();
|
||||
if let Some(code) = term_exit.get("exit_code").and_then(|v| v.as_u64()) {
|
||||
status = status.exit_code(code as u32)
|
||||
}
|
||||
if let Some(signal) = term_exit.get("signal").and_then(|v| v.as_str()) {
|
||||
status = status.signal(signal);
|
||||
}
|
||||
|
||||
let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| {
|
||||
thread.on_terminal_provider_event(
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
|
||||
pub struct EntryViewState {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
project: Entity<Project>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
entries: Vec<Entry>,
|
||||
@@ -34,7 +34,7 @@ pub struct EntryViewState {
|
||||
impl EntryViewState {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
project: Entity<Project>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
|
||||
@@ -328,7 +328,7 @@ impl Entry {
|
||||
|
||||
fn create_terminal(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
project: Entity<Project>,
|
||||
terminal: Entity<acp_thread::Terminal>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -336,9 +336,9 @@ fn create_terminal(
|
||||
cx.new(|cx| {
|
||||
let mut view = TerminalView::new(
|
||||
terminal.read(cx).inner().clone(),
|
||||
workspace,
|
||||
workspace.clone(),
|
||||
None,
|
||||
project,
|
||||
project.downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -458,7 +458,7 @@ mod tests {
|
||||
let view_state = cx.new(|_cx| {
|
||||
EntryViewState::new(
|
||||
workspace.downgrade(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store,
|
||||
None,
|
||||
Default::default(),
|
||||
|
||||
@@ -21,8 +21,8 @@ use editor::{
|
||||
};
|
||||
use futures::{FutureExt as _, future::join_all};
|
||||
use gpui::{
|
||||
AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat,
|
||||
KeyContext, SharedString, Subscription, Task, TextStyle, WeakEntity,
|
||||
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat, KeyContext,
|
||||
SharedString, Subscription, Task, TextStyle, WeakEntity,
|
||||
};
|
||||
use language::{Buffer, Language, language_settings::InlayHintKind};
|
||||
use project::{CompletionIntent, InlayHint, InlayHintLabel, InlayId, Project, Worktree};
|
||||
@@ -39,6 +39,7 @@ use zed_actions::agent::Chat;
|
||||
pub struct MessageEditor {
|
||||
mention_set: Entity<MentionSet>,
|
||||
editor: Entity<Editor>,
|
||||
project: Entity<Project>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
|
||||
available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
|
||||
@@ -97,7 +98,7 @@ impl PromptCompletionProviderDelegate for Entity<MessageEditor> {
|
||||
impl MessageEditor {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
project: Entity<Project>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
|
||||
@@ -123,7 +124,6 @@ impl MessageEditor {
|
||||
let mut editor = Editor::new(mode, buffer, None, window, cx);
|
||||
editor.set_placeholder_text(placeholder, window, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_show_completions_on_input(Some(true));
|
||||
editor.set_soft_wrap();
|
||||
editor.set_use_modal_editing(true);
|
||||
editor.set_context_menu_options(ContextMenuOptions {
|
||||
@@ -134,8 +134,13 @@ impl MessageEditor {
|
||||
editor.register_addon(MessageEditorAddon::new());
|
||||
editor
|
||||
});
|
||||
let mention_set =
|
||||
cx.new(|_cx| MentionSet::new(project, history_store.clone(), prompt_store.clone()));
|
||||
let mention_set = cx.new(|_cx| {
|
||||
MentionSet::new(
|
||||
project.downgrade(),
|
||||
history_store.clone(),
|
||||
prompt_store.clone(),
|
||||
)
|
||||
});
|
||||
let completion_provider = Rc::new(PromptCompletionProvider::new(
|
||||
cx.entity(),
|
||||
editor.downgrade(),
|
||||
@@ -193,6 +198,7 @@ impl MessageEditor {
|
||||
|
||||
Self {
|
||||
editor,
|
||||
project,
|
||||
mention_set,
|
||||
workspace,
|
||||
prompt_capabilities,
|
||||
@@ -417,12 +423,13 @@ impl MessageEditor {
|
||||
))
|
||||
}
|
||||
}
|
||||
Mention::Image(mention_image) => acp::ContentBlock::Image(
|
||||
acp::ImageContent::new(
|
||||
Mention::Image(mention_image) => {
|
||||
let mut image = acp::ImageContent::new(
|
||||
mention_image.data.clone(),
|
||||
mention_image.format.mime_type(),
|
||||
)
|
||||
.uri(match uri {
|
||||
);
|
||||
|
||||
if let Some(uri) = match uri {
|
||||
MentionUri::File { .. } => Some(uri.to_uri().to_string()),
|
||||
MentionUri::PastedImage => None,
|
||||
other => {
|
||||
@@ -432,8 +439,11 @@ impl MessageEditor {
|
||||
);
|
||||
None
|
||||
}
|
||||
}),
|
||||
),
|
||||
} {
|
||||
image = image.uri(uri)
|
||||
};
|
||||
acp::ContentBlock::Image(image)
|
||||
}
|
||||
Mention::Link => acp::ContentBlock::ResourceLink(
|
||||
acp::ResourceLink::new(uri.name(), uri.to_uri().to_string()),
|
||||
),
|
||||
@@ -543,120 +553,6 @@ impl MessageEditor {
|
||||
}
|
||||
|
||||
fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let editor_clipboard_selections = cx
|
||||
.read_from_clipboard()
|
||||
.and_then(|item| item.entries().first().cloned())
|
||||
.and_then(|entry| match entry {
|
||||
ClipboardEntry::String(text) => {
|
||||
text.metadata_json::<Vec<editor::ClipboardSelection>>()
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let has_file_context = editor_clipboard_selections
|
||||
.as_ref()
|
||||
.is_some_and(|selections| {
|
||||
selections
|
||||
.iter()
|
||||
.any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
|
||||
});
|
||||
|
||||
if has_file_context {
|
||||
if let Some((workspace, selections)) =
|
||||
self.workspace.upgrade().zip(editor_clipboard_selections)
|
||||
{
|
||||
cx.stop_propagation();
|
||||
|
||||
let project = workspace.read(cx).project().clone();
|
||||
for selection in selections {
|
||||
if let (Some(file_path), Some(line_range)) =
|
||||
(selection.file_path, selection.line_range)
|
||||
{
|
||||
let crease_text =
|
||||
acp_thread::selection_name(Some(file_path.as_ref()), &line_range);
|
||||
|
||||
let mention_uri = MentionUri::Selection {
|
||||
abs_path: Some(file_path.clone()),
|
||||
line_range: line_range.clone(),
|
||||
};
|
||||
|
||||
let mention_text = mention_uri.as_link().to_string();
|
||||
let (excerpt_id, text_anchor, content_len) =
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
let buffer = editor.buffer().read(cx);
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
let (excerpt_id, _, buffer_snapshot) =
|
||||
snapshot.as_singleton().unwrap();
|
||||
let start_offset = buffer_snapshot.len();
|
||||
let text_anchor = buffer_snapshot.anchor_before(start_offset);
|
||||
|
||||
editor.insert(&mention_text, window, cx);
|
||||
editor.insert(" ", window, cx);
|
||||
|
||||
(*excerpt_id, text_anchor, mention_text.len())
|
||||
});
|
||||
|
||||
let Some((crease_id, tx)) = insert_crease_for_mention(
|
||||
excerpt_id,
|
||||
text_anchor,
|
||||
content_len,
|
||||
crease_text.into(),
|
||||
mention_uri.icon_path(cx),
|
||||
None,
|
||||
self.editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
) else {
|
||||
continue;
|
||||
};
|
||||
drop(tx);
|
||||
|
||||
let mention_task = cx
|
||||
.spawn({
|
||||
let project = project.clone();
|
||||
async move |_, cx| {
|
||||
let project_path = project
|
||||
.update(cx, |project, cx| {
|
||||
project.project_path_for_absolute_path(&file_path, cx)
|
||||
})
|
||||
.map_err(|e| e.to_string())?
|
||||
.ok_or_else(|| "project path not found".to_string())?;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.map_err(|e| e.to_string())?
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
let start = Point::new(*line_range.start(), 0)
|
||||
.min(buffer.max_point());
|
||||
let end = Point::new(*line_range.end() + 1, 0)
|
||||
.min(buffer.max_point());
|
||||
let content =
|
||||
buffer.text_for_range(start..end).collect();
|
||||
Mention::Text {
|
||||
content,
|
||||
tracked_buffers: vec![cx.entity()],
|
||||
}
|
||||
})
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
})
|
||||
.shared();
|
||||
|
||||
self.mention_set.update(cx, |mention_set, _cx| {
|
||||
mention_set.insert_mention(crease_id, mention_uri.clone(), mention_task)
|
||||
});
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if self.prompt_capabilities.borrow().image
|
||||
&& let Some(task) =
|
||||
paste_images_as_context(self.editor.clone(), self.mention_set.clone(), window, cx)
|
||||
@@ -675,18 +571,17 @@ impl MessageEditor {
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let path_style = project.read(cx).path_style(cx);
|
||||
let path_style = self.project.read(cx).path_style(cx);
|
||||
let buffer = self.editor.read(cx).buffer().clone();
|
||||
let Some(buffer) = buffer.read(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for path in paths {
|
||||
let Some(entry) = project.read(cx).entry_for_path(&path, cx) else {
|
||||
let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
|
||||
continue;
|
||||
};
|
||||
let Some(worktree) = project.read(cx).worktree_for_id(path.worktree_id, cx) else {
|
||||
let Some(worktree) = self.project.read(cx).worktree_for_id(path.worktree_id, cx) else {
|
||||
continue;
|
||||
};
|
||||
let abs_path = worktree.read(cx).absolutize(&path.path);
|
||||
@@ -794,13 +689,9 @@ impl MessageEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.clear(window, cx);
|
||||
|
||||
let path_style = workspace.read(cx).project().read(cx).path_style(cx);
|
||||
let path_style = self.project.read(cx).path_style(cx);
|
||||
let mut text = String::new();
|
||||
let mut mentions = Vec::new();
|
||||
|
||||
@@ -1043,7 +934,7 @@ mod tests {
|
||||
cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace.downgrade(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
Default::default(),
|
||||
@@ -1154,7 +1045,7 @@ mod tests {
|
||||
cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace_handle.clone(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
prompt_capabilities.clone(),
|
||||
@@ -1315,7 +1206,7 @@ mod tests {
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace_handle,
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
prompt_capabilities.clone(),
|
||||
@@ -1537,7 +1428,7 @@ mod tests {
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace_handle,
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
prompt_capabilities.clone(),
|
||||
@@ -2028,7 +1919,7 @@ mod tests {
|
||||
cx.new(|cx| {
|
||||
let editor = MessageEditor::new(
|
||||
workspace.downgrade(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
Default::default(),
|
||||
@@ -2133,7 +2024,7 @@ mod tests {
|
||||
cx.new(|cx| {
|
||||
let mut editor = MessageEditor::new(
|
||||
workspace.downgrade(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
Default::default(),
|
||||
@@ -2202,7 +2093,7 @@ mod tests {
|
||||
cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace.downgrade(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
Default::default(),
|
||||
@@ -2265,7 +2156,7 @@ mod tests {
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace_handle,
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
Default::default(),
|
||||
@@ -2423,7 +2314,7 @@ mod tests {
|
||||
let message_editor = cx.new(|cx| {
|
||||
MessageEditor::new(
|
||||
workspace_handle,
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
Default::default(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
@@ -292,12 +292,18 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.when_some(model_info.icon, |this, icon| {
|
||||
this.child(
|
||||
Icon::new(icon)
|
||||
.map(|this| match &model_info.icon {
|
||||
Some(AgentModelIcon::Path(path)) => this.child(
|
||||
Icon::from_path(path.clone())
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
)
|
||||
.size(IconSize::Small),
|
||||
),
|
||||
Some(AgentModelIcon::Named(icon)) => this.child(
|
||||
Icon::new(*icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
),
|
||||
None => this,
|
||||
})
|
||||
.child(Label::new(model_info.name.clone()).truncate()),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelSelector};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
@@ -64,7 +64,7 @@ impl Render for AcpModelSelectorPopover {
|
||||
.map(|model| model.name.clone())
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon);
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon.clone());
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
@@ -78,8 +78,13 @@ impl Render for AcpModelSelectorPopover {
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.when_some(model_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
.when_some(model_icon, |this, icon| match icon {
|
||||
AgentModelIcon::Path(path) => {
|
||||
this.child(Icon::from_path(path).color(color).size(IconSize::XSmall))
|
||||
}
|
||||
AgentModelIcon::Named(icon_name) => {
|
||||
this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall))
|
||||
}
|
||||
})
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
|
||||
@@ -100,7 +100,7 @@ impl ThreadError {
|
||||
{
|
||||
Self::ModelRequestLimitReached(error.plan)
|
||||
} else if let Some(acp_error) = error.downcast_ref::<acp::Error>()
|
||||
&& acp_error.code == acp::ErrorCode::AuthRequired
|
||||
&& acp_error.code == acp::ErrorCode::AUTH_REQUIRED.code
|
||||
{
|
||||
Self::AuthenticationRequired(acp_error.message.clone().into())
|
||||
} else {
|
||||
@@ -344,7 +344,7 @@ impl AcpThreadView {
|
||||
let message_editor = cx.new(|cx| {
|
||||
let mut editor = MessageEditor::new(
|
||||
workspace.clone(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
prompt_store.clone(),
|
||||
prompt_capabilities.clone(),
|
||||
@@ -369,7 +369,7 @@ impl AcpThreadView {
|
||||
let entry_view_state = cx.new(|_| {
|
||||
EntryViewState::new(
|
||||
workspace.clone(),
|
||||
project.downgrade(),
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
prompt_store.clone(),
|
||||
prompt_capabilities.clone(),
|
||||
@@ -6243,7 +6243,7 @@ pub(crate) mod tests {
|
||||
StubAgentConnection::new().with_permission_requests(HashMap::from_iter([(
|
||||
tool_call_id,
|
||||
vec![acp::PermissionOption::new(
|
||||
"1",
|
||||
"1".into(),
|
||||
"Allow",
|
||||
acp::PermissionOptionKind::AllowOnce,
|
||||
)],
|
||||
|
||||
@@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file};
|
||||
use ui::{
|
||||
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
|
||||
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
|
||||
PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
|
||||
PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{Workspace, create_and_open_local_file};
|
||||
@@ -260,11 +260,15 @@ impl AgentConfiguration {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
.child(if let Some(icon_path) = provider.icon_path() {
|
||||
Icon::from_external_svg(icon_path)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted)
|
||||
} else {
|
||||
Icon::new(provider.icon())
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.color(Color::Muted)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
@@ -879,6 +883,7 @@ impl AgentConfiguration {
|
||||
.child(context_server_configuration_menu)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager = self.context_server_store.clone();
|
||||
let fs = self.fs.clone();
|
||||
|
||||
@@ -73,7 +73,8 @@ impl Render for AgentModelSelector {
|
||||
.map(|model| model.model.name().0)
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let provider_icon = model.as_ref().map(|model| model.provider.icon());
|
||||
let provider_icon_path = model.as_ref().and_then(|model| model.provider.icon_path());
|
||||
let provider_icon_name = model.as_ref().map(|model| model.provider.icon());
|
||||
let color = if self.menu_handle.is_deployed() {
|
||||
Color::Accent
|
||||
} else {
|
||||
@@ -85,8 +86,17 @@ impl Render for AgentModelSelector {
|
||||
PickerPopoverMenu::new(
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.when_some(provider_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
.when_some(provider_icon_path.clone(), |this, icon_path| {
|
||||
this.child(
|
||||
Icon::from_external_svg(icon_path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
)
|
||||
})
|
||||
.when(provider_icon_path.is_none(), |this| {
|
||||
this.when_some(provider_icon_name, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
})
|
||||
})
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.child(
|
||||
@@ -98,7 +108,7 @@ impl Render for AgentModelSelector {
|
||||
.child(
|
||||
Icon::new(IconName::ChevronDown)
|
||||
.color(color)
|
||||
.size(IconSize::Small),
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
move |_window, cx| {
|
||||
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
|
||||
|
||||
@@ -346,9 +346,13 @@ fn init_language_model_settings(cx: &mut App) {
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|_, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
language_model::Event::ProviderStateChanged(_) => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
}
|
||||
language_model::Event::AddedProvider(_) => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
}
|
||||
language_model::Event::RemovedProvider(_) => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -1,34 +1,26 @@
|
||||
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
channel::mpsc,
|
||||
future::{LocalBoxFuture, Shared},
|
||||
join,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolUse, Role, TokenUsage,
|
||||
report_assistant_event,
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelTextStream, Role, report_assistant_event,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use prompt_store::PromptBuilder;
|
||||
use rope::Rope;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
cmp,
|
||||
@@ -42,29 +34,6 @@ use std::{
|
||||
};
|
||||
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use ui::SharedString;
|
||||
|
||||
/// Use this tool to provide a message to the user when you're unable to complete a task.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FailureMessageInput {
|
||||
/// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
|
||||
///
|
||||
/// The message may use markdown formatting if you wish.
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct RewriteSectionInput {
|
||||
/// A brief description of the edit you have made.
|
||||
///
|
||||
/// The description may use markdown formatting if you wish.
|
||||
/// This is optional - if the edit is simple or obvious, you should leave it empty.
|
||||
pub description: String,
|
||||
|
||||
/// The text to replace the section with.
|
||||
pub replacement_text: String,
|
||||
}
|
||||
|
||||
pub struct BufferCodegen {
|
||||
alternatives: Vec<Entity<CodegenAlternative>>,
|
||||
@@ -269,7 +238,6 @@ pub struct CodegenAlternative {
|
||||
elapsed_time: Option<f64>,
|
||||
completion: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
pub model_explanation: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
|
||||
@@ -320,15 +288,14 @@ impl CodegenAlternative {
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
builder,
|
||||
active: active,
|
||||
active,
|
||||
edits: Vec::new(),
|
||||
line_operations: Vec::new(),
|
||||
range,
|
||||
elapsed_time: None,
|
||||
completion: None,
|
||||
model_explanation: None,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,131 +358,18 @@ impl CodegenAlternative {
|
||||
let api_key = model.api_key(cx);
|
||||
let telemetry_id = model.telemetry_id();
|
||||
let provider_id = model.provider_id();
|
||||
|
||||
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
|
||||
let request = self.build_request(&model, user_prompt, context_task, cx)?;
|
||||
let completion_events =
|
||||
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
|
||||
self.generation = self.handle_completion(
|
||||
telemetry_id,
|
||||
provider_id.to_string(),
|
||||
api_key,
|
||||
completion_events,
|
||||
cx,
|
||||
);
|
||||
} else {
|
||||
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.build_request(&model, user_prompt, context_task, cx)?;
|
||||
cx.spawn(async move |_, cx| {
|
||||
Ok(model.stream_completion_text(request.await, cx).await?)
|
||||
})
|
||||
.boxed_local()
|
||||
};
|
||||
self.generation =
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request_v2(
|
||||
&self,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
user_prompt: String,
|
||||
context_task: Shared<Task<Option<LoadedContext>>>,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||
None
|
||||
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||
} else {
|
||||
Some(language.name())
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let language_name = language_name.as_ref();
|
||||
let start = buffer.point_to_buffer_offset(self.range.start);
|
||||
let end = buffer.point_to_buffer_offset(self.range.end);
|
||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||
let (start_buffer, start_buffer_offset) = start;
|
||||
let (end_buffer, end_buffer_offset) = end;
|
||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||
} else {
|
||||
anyhow::bail!("invalid transformation range");
|
||||
}
|
||||
} else {
|
||||
anyhow::bail!("invalid transformation range");
|
||||
};
|
||||
|
||||
let system_prompt = self
|
||||
.builder
|
||||
.generate_inline_transformation_prompt_v2(
|
||||
language_name,
|
||||
buffer,
|
||||
range.start.0..range.end.0,
|
||||
)
|
||||
.context("generating content prompt")?;
|
||||
|
||||
let temperature = AgentSettings::temperature_for_model(model, cx);
|
||||
|
||||
let tool_input_format = model.tool_input_format();
|
||||
|
||||
Ok(cx.spawn(async move |_cx| {
|
||||
let mut messages = vec![LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![system_prompt.into()],
|
||||
cache: false,
|
||||
reasoning_details: None,
|
||||
}];
|
||||
|
||||
let mut user_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
reasoning_details: None,
|
||||
let request = self.build_request(&model, user_prompt, context_task, cx)?;
|
||||
cx.spawn(async move |_, cx| {
|
||||
Ok(model.stream_completion_text(request.await, cx).await?)
|
||||
})
|
||||
.boxed_local()
|
||||
};
|
||||
|
||||
if let Some(context) = context_task.await {
|
||||
context.add_to_request_message(&mut user_message);
|
||||
}
|
||||
|
||||
user_message.content.push(user_prompt.into());
|
||||
messages.push(user_message);
|
||||
|
||||
let tools = vec![
|
||||
LanguageModelRequestTool {
|
||||
name: "rewrite_section".to_string(),
|
||||
description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
|
||||
input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
|
||||
},
|
||||
LanguageModelRequestTool {
|
||||
name: "failure_message".to_string(),
|
||||
description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
|
||||
input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
|
||||
},
|
||||
];
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(CompletionIntent::InlineAssist),
|
||||
mode: None,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature,
|
||||
messages,
|
||||
thinking_allowed: false,
|
||||
}
|
||||
}))
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
@@ -525,10 +379,6 @@ impl CodegenAlternative {
|
||||
context_task: Shared<Task<Option<LoadedContext>>>,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
|
||||
return self.build_request_v2(model, user_prompt, context_task, cx);
|
||||
}
|
||||
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
@@ -597,21 +447,6 @@ impl CodegenAlternative {
|
||||
}))
|
||||
}
|
||||
|
||||
// stream: impl Future<Output = Result<InlineAssistantStream>>
|
||||
// impl Stream for InlineAssistantStream {
|
||||
// type Output = InlineAssistantChunk
|
||||
// }
|
||||
//
|
||||
// enum InlineAssistantChunk {
|
||||
// rewrite_text(String)
|
||||
// Error(Err)
|
||||
// }
|
||||
// explanation_text(String)
|
||||
//
|
||||
//
|
||||
//
|
||||
// handle_completion_stream
|
||||
|
||||
pub fn handle_stream(
|
||||
&mut self,
|
||||
model_telemetry_id: String,
|
||||
@@ -619,7 +454,7 @@ impl CodegenAlternative {
|
||||
model_api_key: Option<String>,
|
||||
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<()> {
|
||||
) {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Make a new snapshot and re-resolve anchor in case the document was modified.
|
||||
@@ -673,10 +508,8 @@ impl CodegenAlternative {
|
||||
let completion = Arc::new(Mutex::new(String::new()));
|
||||
let completion_clone = completion.clone();
|
||||
|
||||
cx.notify();
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
let stream = stream.await;
|
||||
|
||||
let token_usage = stream
|
||||
.as_ref()
|
||||
.ok()
|
||||
@@ -700,42 +533,6 @@ impl CodegenAlternative {
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
// impl Stream<Output = Result<String>>;
|
||||
|
||||
// struct StreamingDiffLoop {
|
||||
// diff: StreamingDiff,
|
||||
// line_diff: LineDiff,
|
||||
// new_text: String,
|
||||
// base_indent: Option<usize>,
|
||||
// line_indent: Option<usize>,
|
||||
// first_line: bool,
|
||||
// }
|
||||
|
||||
// impl StreamingDiffLoop {
|
||||
// fn new(selected_text: &str) -> Self {
|
||||
// Self {
|
||||
// diff: StreamingDiff::new(selected_text.to_string()),
|
||||
// line_diff: LineDiff::default(),
|
||||
// new_text: String::new(),
|
||||
// base_indent: None,
|
||||
// line_indent: None,
|
||||
// first_line: true,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// let diff_loop = StreamingDiffLoop::new(selected_text.to_string());
|
||||
|
||||
// while let Some(chunk) = chunks.next().await {
|
||||
// if response_latency.is_none() {
|
||||
// response_latency = Some(request_start.elapsed());
|
||||
// }
|
||||
// let chunk = chunk?;
|
||||
// completion_clone.lock().push_str(&chunk);
|
||||
|
||||
// diff_loop.push(chunk, suggested_line_indent, selection_start, selected_text, diff_tx);
|
||||
// }
|
||||
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -927,7 +724,8 @@ impl CodegenAlternative {
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn stop(&mut self, cx: &mut Context<Self>) {
|
||||
@@ -1101,163 +899,6 @@ impl CodegenAlternative {
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_completion(
|
||||
&mut self,
|
||||
telemetry_id: String,
|
||||
provider_id: String,
|
||||
api_key: Option<String>,
|
||||
completion_stream: Task<
|
||||
Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<()> {
|
||||
self.diff = Diff::default();
|
||||
self.status = CodegenStatus::Pending;
|
||||
|
||||
cx.notify();
|
||||
// Leaving this in generation so that STOP equivalent events are respected even
|
||||
// while we're still pre-processing the completion event
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
|
||||
let _ = codegen.update(cx, |this, cx| {
|
||||
this.status = status;
|
||||
cx.emit(CodegenEvent::Finished);
|
||||
cx.notify();
|
||||
});
|
||||
};
|
||||
|
||||
let mut completion_events = match completion_stream.await {
|
||||
Ok(events) => events,
|
||||
Err(err) => {
|
||||
finish_with_status(CodegenStatus::Error(err.into()), cx);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let chars_read_so_far = Arc::new(Mutex::new(0usize));
|
||||
let tool_to_text = move |tool_use: LanguageModelToolUse| -> String {
|
||||
let mut chars_read_so_far = chars_read_so_far.lock();
|
||||
dbg!(&tool_use);
|
||||
let input: RewriteSectionInput =
|
||||
serde_json::from_value(tool_use.input.clone()).unwrap();
|
||||
let value = input.replacement_text[*chars_read_so_far..].to_string();
|
||||
*chars_read_so_far = value.len();
|
||||
value
|
||||
};
|
||||
|
||||
let mut message_id = None;
|
||||
let mut first_text = None;
|
||||
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
|
||||
let total_text = Arc::new(Mutex::new(String::new()));
|
||||
|
||||
loop {
|
||||
if let Some(first_event) = completion_events.next().await {
|
||||
dbg!(&first_event);
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
dbg!("AAA 0");
|
||||
message_id = Some(id);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if tool_use.name.as_ref() == "rewrite_section" =>
|
||||
{
|
||||
dbg!("AAA 1");
|
||||
first_text = Some(tool_to_text(tool_use));
|
||||
break;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
}
|
||||
Ok(e) => {
|
||||
log::warn!("Unexpected event: {:?}", e);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let text = total_text.lock().clone();
|
||||
dbg!(text);
|
||||
|
||||
let Some(first_text) = first_text else {
|
||||
finish_with_status(
|
||||
CodegenStatus::Error(anyhow!("Failed to start????").into()),
|
||||
cx,
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
let move_last_token_usage = last_token_usage.clone();
|
||||
|
||||
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
|
||||
completion_events.filter_map(move |e| {
|
||||
let tool_to_text = tool_to_text.clone();
|
||||
let last_token_usage = move_last_token_usage.clone();
|
||||
let total_text = total_text.clone();
|
||||
async move {
|
||||
match e {
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if tool_use.name.as_ref() == "rewrite_section" =>
|
||||
{
|
||||
Some(Ok(tool_to_text(tool_use)))
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
None
|
||||
}
|
||||
e => {
|
||||
println!("UNEXPECTED EVENT {:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
));
|
||||
|
||||
let language_model_text_stream = LanguageModelTextStream {
|
||||
message_id: message_id,
|
||||
stream: text_stream,
|
||||
last_token_usage,
|
||||
};
|
||||
|
||||
let Some(task) = codegen
|
||||
.update(cx, move |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
telemetry_id,
|
||||
provider_id,
|
||||
api_key,
|
||||
async { Ok(language_model_text_stream) },
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
task.await;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
@@ -1419,9 +1060,8 @@ mod tests {
|
||||
};
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::{Buffer, Point};
|
||||
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, Point, tree_sitter_rust};
|
||||
use language_model::{LanguageModelRegistry, TokenUsage};
|
||||
use languages::rust_lang;
|
||||
use rand::prelude::*;
|
||||
use settings::SettingsStore;
|
||||
use std::{future, sync::Arc};
|
||||
@@ -1438,7 +1078,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
@@ -1500,7 +1140,7 @@ mod tests {
|
||||
le
|
||||
}
|
||||
"};
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
@@ -1564,7 +1204,7 @@ mod tests {
|
||||
" \n",
|
||||
"}\n" //
|
||||
);
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
@@ -1680,7 +1320,7 @@ mod tests {
|
||||
let x = 0;
|
||||
}
|
||||
"};
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
@@ -1783,7 +1423,7 @@ mod tests {
|
||||
) -> mpsc::UnboundedSender<String> {
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.generation = codegen.handle_stream(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
String::new(),
|
||||
None,
|
||||
@@ -1797,4 +1437,27 @@ mod tests {
|
||||
});
|
||||
chunks_tx
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_indents_query(
|
||||
r#"
|
||||
(call_expression) @indent
|
||||
(field_expression) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
(_ "{" "}" @end) @indent
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -387,9 +387,17 @@ impl InlineAssistant {
|
||||
let mut selections = Vec::<Selection<Point>>::new();
|
||||
let mut newest_selection = None;
|
||||
for mut selection in initial_selections {
|
||||
if selection.end == selection.start
|
||||
&& let Some(fold) =
|
||||
snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
|
||||
if selection.end > selection.start {
|
||||
selection.start.column = 0;
|
||||
// If the selection ends at the start of the line, we don't want to include it.
|
||||
if selection.end.column == 0 {
|
||||
selection.end.row -= 1;
|
||||
}
|
||||
selection.end.column = snapshot
|
||||
.buffer_snapshot()
|
||||
.line_len(MultiBufferRow(selection.end.row));
|
||||
} else if let Some(fold) =
|
||||
snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
|
||||
{
|
||||
selection.start = fold.range().start;
|
||||
selection.end = fold.range().end;
|
||||
@@ -416,15 +424,6 @@ impl InlineAssistant {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
selection.start.column = 0;
|
||||
// If the selection ends at the start of the line, we don't want to include it.
|
||||
if selection.end.column == 0 && selection.start.row != selection.end.row {
|
||||
selection.end.row -= 1;
|
||||
}
|
||||
selection.end.column = snapshot
|
||||
.buffer_snapshot()
|
||||
.line_len(MultiBufferRow(selection.end.row));
|
||||
}
|
||||
|
||||
if let Some(prev_selection) = selections.last_mut()
|
||||
@@ -545,15 +544,14 @@ impl InlineAssistant {
|
||||
}
|
||||
}
|
||||
|
||||
let [prompt_block_id, tool_description_block_id, end_block_id] =
|
||||
self.insert_assist_blocks(&editor, &range, &prompt_editor, cx);
|
||||
let [prompt_block_id, end_block_id] =
|
||||
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
|
||||
|
||||
assists.push((
|
||||
assist_id,
|
||||
range.clone(),
|
||||
prompt_editor,
|
||||
prompt_block_id,
|
||||
tool_description_block_id,
|
||||
end_block_id,
|
||||
));
|
||||
}
|
||||
@@ -572,15 +570,7 @@ impl InlineAssistant {
|
||||
};
|
||||
|
||||
let mut assist_group = InlineAssistGroup::new();
|
||||
for (
|
||||
assist_id,
|
||||
range,
|
||||
prompt_editor,
|
||||
prompt_block_id,
|
||||
tool_description_block_id,
|
||||
end_block_id,
|
||||
) in assists
|
||||
{
|
||||
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
|
||||
let codegen = prompt_editor.read(cx).codegen().clone();
|
||||
|
||||
self.assists.insert(
|
||||
@@ -591,7 +581,6 @@ impl InlineAssistant {
|
||||
editor,
|
||||
&prompt_editor,
|
||||
prompt_block_id,
|
||||
tool_description_block_id,
|
||||
end_block_id,
|
||||
range,
|
||||
codegen,
|
||||
@@ -700,7 +689,7 @@ impl InlineAssistant {
|
||||
range: &Range<Anchor>,
|
||||
prompt_editor: &Entity<PromptEditor<BufferCodegen>>,
|
||||
cx: &mut App,
|
||||
) -> [CustomBlockId; 3] {
|
||||
) -> [CustomBlockId; 2] {
|
||||
let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
|
||||
prompt_editor
|
||||
.editor
|
||||
@@ -714,14 +703,6 @@ impl InlineAssistant {
|
||||
render: build_assist_editor_renderer(prompt_editor),
|
||||
priority: 0,
|
||||
},
|
||||
// Placeholder for tool description - will be updated dynamically
|
||||
BlockProperties {
|
||||
style: BlockStyle::Flex,
|
||||
placement: BlockPlacement::Below(range.end),
|
||||
height: Some(0),
|
||||
render: Arc::new(|_cx| div().into_any_element()),
|
||||
priority: 0,
|
||||
},
|
||||
BlockProperties {
|
||||
style: BlockStyle::Sticky,
|
||||
placement: BlockPlacement::Below(range.end),
|
||||
@@ -740,7 +721,7 @@ impl InlineAssistant {
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
let block_ids = editor.insert_blocks(assist_blocks, None, cx);
|
||||
[block_ids[0], block_ids[1], block_ids[2]]
|
||||
[block_ids[0], block_ids[1]]
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1132,9 +1113,6 @@ impl InlineAssistant {
|
||||
let mut to_remove = decorations.removed_line_block_ids;
|
||||
to_remove.insert(decorations.prompt_block_id);
|
||||
to_remove.insert(decorations.end_block_id);
|
||||
if let Some(tool_description_block_id) = decorations.model_explanation {
|
||||
to_remove.insert(tool_description_block_id);
|
||||
}
|
||||
editor.remove_blocks(to_remove, None, cx);
|
||||
});
|
||||
|
||||
@@ -1455,60 +1433,8 @@ impl InlineAssistant {
|
||||
let old_snapshot = codegen.snapshot(cx);
|
||||
let old_buffer = codegen.old_buffer(cx);
|
||||
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
|
||||
// let model_explanation = codegen.model_explanation(cx);
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
// Update tool description block
|
||||
// if let Some(description) = model_explanation {
|
||||
// if let Some(block_id) = decorations.model_explanation {
|
||||
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
// let new_block_id = editor.insert_blocks(
|
||||
// [BlockProperties {
|
||||
// style: BlockStyle::Flex,
|
||||
// placement: BlockPlacement::Below(assist.range.end),
|
||||
// height: Some(1),
|
||||
// render: Arc::new({
|
||||
// let description = description.clone();
|
||||
// move |cx| {
|
||||
// div()
|
||||
// .w_full()
|
||||
// .py_1()
|
||||
// .px_2()
|
||||
// .bg(cx.theme().colors().editor_background)
|
||||
// .border_y_1()
|
||||
// .border_color(cx.theme().status().info_border)
|
||||
// .child(
|
||||
// Label::new(description.clone())
|
||||
// .color(Color::Muted)
|
||||
// .size(LabelSize::Small),
|
||||
// )
|
||||
// .into_any_element()
|
||||
// }
|
||||
// }),
|
||||
// priority: 0,
|
||||
// }],
|
||||
// None,
|
||||
// cx,
|
||||
// );
|
||||
// decorations.model_explanation = new_block_id.into_iter().next();
|
||||
// }
|
||||
// } else if let Some(block_id) = decorations.model_explanation {
|
||||
// // Hide the block if there's no description
|
||||
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
// let new_block_id = editor.insert_blocks(
|
||||
// [BlockProperties {
|
||||
// style: BlockStyle::Flex,
|
||||
// placement: BlockPlacement::Below(assist.range.end),
|
||||
// height: Some(0),
|
||||
// render: Arc::new(|_cx| div().into_any_element()),
|
||||
// priority: 0,
|
||||
// }],
|
||||
// None,
|
||||
// cx,
|
||||
// );
|
||||
// decorations.model_explanation = new_block_id.into_iter().next();
|
||||
// }
|
||||
|
||||
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
|
||||
editor.remove_blocks(old_blocks, None, cx);
|
||||
|
||||
@@ -1760,7 +1686,6 @@ impl InlineAssist {
|
||||
editor: &Entity<Editor>,
|
||||
prompt_editor: &Entity<PromptEditor<BufferCodegen>>,
|
||||
prompt_block_id: CustomBlockId,
|
||||
tool_description_block_id: CustomBlockId,
|
||||
end_block_id: CustomBlockId,
|
||||
range: Range<Anchor>,
|
||||
codegen: Entity<BufferCodegen>,
|
||||
@@ -1775,8 +1700,7 @@ impl InlineAssist {
|
||||
decorations: Some(InlineAssistDecorations {
|
||||
prompt_block_id,
|
||||
prompt_editor: prompt_editor.clone(),
|
||||
removed_line_block_ids: Default::default(),
|
||||
model_explanation: Some(tool_description_block_id),
|
||||
removed_line_block_ids: HashSet::default(),
|
||||
end_block_id,
|
||||
}),
|
||||
range,
|
||||
@@ -1880,7 +1804,6 @@ struct InlineAssistDecorations {
|
||||
prompt_block_id: CustomBlockId,
|
||||
prompt_editor: Entity<PromptEditor<BufferCodegen>>,
|
||||
removed_line_block_ids: HashSet<CustomBlockId>,
|
||||
model_explanation: Option<CustomBlockId>,
|
||||
end_block_id: CustomBlockId,
|
||||
}
|
||||
|
||||
|
||||
@@ -10,11 +10,10 @@ use editor::{
|
||||
};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, Subscription,
|
||||
TextStyle, TextStyleRefinement, WeakEntity, Window,
|
||||
AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable,
|
||||
Subscription, TextStyle, WeakEntity, Window,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
@@ -66,7 +65,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
|
||||
const RIGHT_PADDING: Pixels = px(9.);
|
||||
|
||||
let (left_gutter_width, right_padding, explanation) = match &self.mode {
|
||||
let (left_gutter_width, right_padding) = match &self.mode {
|
||||
PromptEditorMode::Buffer {
|
||||
id: _,
|
||||
codegen,
|
||||
@@ -84,23 +83,17 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0);
|
||||
let right_padding = editor_margins.right + RIGHT_PADDING;
|
||||
|
||||
let explanation = codegen
|
||||
.active_alternative()
|
||||
.read(cx)
|
||||
.model_explanation
|
||||
.clone();
|
||||
|
||||
(left_gutter_width, right_padding, explanation)
|
||||
(left_gutter_width, right_padding)
|
||||
}
|
||||
PromptEditorMode::Terminal { .. } => {
|
||||
// Give the equivalent of the same left-padding that we're using on the right
|
||||
(Pixels::from(40.0), Pixels::from(24.), None)
|
||||
(Pixels::from(40.0), Pixels::from(24.))
|
||||
}
|
||||
};
|
||||
|
||||
let bottom_padding = match &self.mode {
|
||||
PromptEditorMode::Buffer { .. } => rems_from_px(2.0),
|
||||
PromptEditorMode::Terminal { .. } => rems_from_px(4.0),
|
||||
PromptEditorMode::Terminal { .. } => rems_from_px(8.0),
|
||||
};
|
||||
|
||||
buttons.extend(self.render_buttons(window, cx));
|
||||
@@ -118,33 +111,22 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
this.trigger_completion_menu(window, cx);
|
||||
}));
|
||||
|
||||
let markdown = window.use_state(cx, |_, cx| Markdown::new("".into(), None, None, cx));
|
||||
|
||||
if let Some(explanation) = &explanation {
|
||||
markdown.update(cx, |markdown, cx| {
|
||||
markdown.reset(explanation.clone(), cx);
|
||||
});
|
||||
}
|
||||
|
||||
let explanation_label = self
|
||||
.render_markdown(markdown, markdown_style(window, cx))
|
||||
.into_any_element();
|
||||
|
||||
v_flex()
|
||||
.key_context("PromptEditor")
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.block_mouse_except_scroll()
|
||||
.gap_0p5()
|
||||
.border_y_1()
|
||||
.border_color(cx.theme().status().info_border)
|
||||
.size_full()
|
||||
.pt_0p5()
|
||||
.pb(bottom_padding)
|
||||
.pr(right_padding)
|
||||
.gap_0p5()
|
||||
.justify_center()
|
||||
.border_y_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.items_start()
|
||||
.cursor(CursorStyle::Arrow)
|
||||
.on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| {
|
||||
this.model_selector
|
||||
.update(cx, |model_selector, cx| model_selector.toggle(window, cx));
|
||||
@@ -157,14 +139,14 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
.capture_action(cx.listener(Self::cycle_next))
|
||||
.child(
|
||||
WithRemSize::new(ui_font_size)
|
||||
.h_full()
|
||||
.w(left_gutter_width)
|
||||
.flex()
|
||||
.flex_row()
|
||||
.flex_shrink_0()
|
||||
.items_center()
|
||||
.h_full()
|
||||
.w(left_gutter_width)
|
||||
.justify_center()
|
||||
.gap_1()
|
||||
.gap_2()
|
||||
.child(self.render_close_button(cx))
|
||||
.map(|el| {
|
||||
let CodegenStatus::Error(error) = self.codegen_status(cx) else {
|
||||
@@ -195,83 +177,26 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
.flex_row()
|
||||
.items_center()
|
||||
.gap_1()
|
||||
.child(add_context_button)
|
||||
.child(self.model_selector.clone())
|
||||
.children(buttons),
|
||||
),
|
||||
),
|
||||
)
|
||||
.when_some(explanation, |this, _| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.size_full()
|
||||
.justify_center()
|
||||
.child(div().w(left_gutter_width + px(6.)))
|
||||
.child(
|
||||
div()
|
||||
.size_full()
|
||||
.min_w_0()
|
||||
.pt(rems_from_px(3.))
|
||||
.pl_0p5()
|
||||
.flex_1()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.child(explanation_label),
|
||||
),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let colors = cx.theme().colors();
|
||||
let mut text_style = window.text_style();
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
color: Some(colors.text),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
MarkdownStyle {
|
||||
base_text_style: text_style.clone(),
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
selection_background_color: colors.element_selection_background,
|
||||
heading_level_styles: Some(HeadingLevelStyles {
|
||||
h1: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.15).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h2: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.1).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h3: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.05).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h4: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(1.).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h5: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(0.95).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
h6: Some(TextStyleRefinement {
|
||||
font_size: Some(rems(0.875).into()),
|
||||
..Default::default()
|
||||
}),
|
||||
}),
|
||||
inline_code: TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
background_color: Some(colors.editor_foreground.opacity(0.08)),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
.child(
|
||||
WithRemSize::new(ui_font_size)
|
||||
.flex()
|
||||
.flex_row()
|
||||
.items_center()
|
||||
.child(h_flex().flex_shrink_0().w(left_gutter_width))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.pl_1()
|
||||
.items_start()
|
||||
.justify_between()
|
||||
.child(add_context_button)
|
||||
.child(self.model_selector.clone()),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -834,10 +759,6 @@ impl<T: 'static> PromptEditor<T> {
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_markdown(&self, markdown: Entity<Markdown>, style: MarkdownStyle) -> MarkdownElement {
|
||||
MarkdownElement::new(markdown, style)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum PromptEditorMode {
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use std::{cmp::Reverse, sync::Arc};
|
||||
|
||||
use collections::IndexMap;
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||
use gpui::{
|
||||
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
|
||||
};
|
||||
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
|
||||
LanguageModelRegistry,
|
||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
@@ -57,12 +56,12 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
model,
|
||||
icon: provider.icon(),
|
||||
icon: ProviderIcon::from_provider(provider.as_ref()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let all = providers
|
||||
let all: Vec<ModelInfo> = providers
|
||||
.iter()
|
||||
.flat_map(|provider| {
|
||||
provider
|
||||
@@ -70,7 +69,7 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
model,
|
||||
icon: provider.icon(),
|
||||
icon: ProviderIcon::from_provider(provider.as_ref()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -78,10 +77,26 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
GroupedModels::new(all, recommended)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ProviderIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
impl ProviderIcon {
|
||||
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
|
||||
if let Some(path) = provider.icon_path() {
|
||||
Self::Path(path)
|
||||
} else {
|
||||
Self::Name(provider.icon())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ModelInfo {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
icon: IconName,
|
||||
icon: ProviderIcon,
|
||||
}
|
||||
|
||||
pub struct LanguageModelPickerDelegate {
|
||||
@@ -91,7 +106,7 @@ pub struct LanguageModelPickerDelegate {
|
||||
filtered_entries: Vec<LanguageModelPickerEntry>,
|
||||
selected_index: usize,
|
||||
_authenticate_all_providers_task: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
_refresh_models_task: Task<()>,
|
||||
popover_styles: bool,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
@@ -116,24 +131,40 @@ impl LanguageModelPickerDelegate {
|
||||
filtered_entries: entries,
|
||||
get_active_model: Arc::new(get_active_model),
|
||||
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
||||
_subscriptions: vec![cx.subscribe_in(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|picker, _, event, window, cx| {
|
||||
match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let query = picker.query(cx);
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
// Update matches will automatically drop the previous task
|
||||
// if we get a provider event again
|
||||
picker.update_matches(query, window, cx)
|
||||
}
|
||||
_ => {}
|
||||
_refresh_models_task: {
|
||||
// Create a channel to signal when models need refreshing
|
||||
let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
|
||||
|
||||
// Subscribe to registry events and send refresh signals through the channel
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
},
|
||||
)],
|
||||
language_model::Event::AddedProvider(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
language_model::Event::RemovedProvider(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// Spawn a task that listens for refresh signals and updates the picker
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(()) = refresh_rx.next().await {
|
||||
let result = this.update_in(cx, |picker, window, cx| {
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
picker.refresh(window, cx);
|
||||
});
|
||||
if result.is_err() {
|
||||
// Picker was dropped, exit the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
popover_styles,
|
||||
focus_handle,
|
||||
}
|
||||
@@ -504,11 +535,16 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(model_info.icon)
|
||||
.child(match &model_info.icon {
|
||||
ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
ProviderIcon::Path(icon_path) => {
|
||||
Icon::from_external_svg(icon_path.clone())
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
}
|
||||
})
|
||||
.child(Label::new(model_info.model.name().0).truncate()),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
@@ -657,7 +693,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(|(provider, name)| ModelInfo {
|
||||
model: Arc::new(TestLanguageModel::new(name, provider)),
|
||||
icon: IconName::Ai,
|
||||
icon: ProviderIcon::Name(IconName::Ai),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -1682,98 +1682,6 @@ impl TextThreadEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let editor_clipboard_selections = cx
|
||||
.read_from_clipboard()
|
||||
.and_then(|item| item.entries().first().cloned())
|
||||
.and_then(|entry| match entry {
|
||||
ClipboardEntry::String(text) => {
|
||||
text.metadata_json::<Vec<editor::ClipboardSelection>>()
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let has_file_context = editor_clipboard_selections
|
||||
.as_ref()
|
||||
.is_some_and(|selections| {
|
||||
selections
|
||||
.iter()
|
||||
.any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
|
||||
});
|
||||
|
||||
if has_file_context {
|
||||
if let Some(clipboard_item) = cx.read_from_clipboard() {
|
||||
if let Some(ClipboardEntry::String(clipboard_text)) =
|
||||
clipboard_item.entries().first()
|
||||
{
|
||||
if let Some(selections) = editor_clipboard_selections {
|
||||
cx.stop_propagation();
|
||||
|
||||
let text = clipboard_text.text();
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
let mut current_offset = 0;
|
||||
let weak_editor = cx.entity().downgrade();
|
||||
|
||||
for selection in selections {
|
||||
if let (Some(file_path), Some(line_range)) =
|
||||
(selection.file_path, selection.line_range)
|
||||
{
|
||||
let selected_text =
|
||||
&text[current_offset..current_offset + selection.len];
|
||||
let fence = assistant_slash_commands::codeblock_fence_for_path(
|
||||
file_path.to_str(),
|
||||
Some(line_range.clone()),
|
||||
);
|
||||
let formatted_text = format!("{fence}{selected_text}\n```");
|
||||
|
||||
let insert_point = editor
|
||||
.selections
|
||||
.newest::<Point>(&editor.display_snapshot(cx))
|
||||
.head();
|
||||
let start_row = MultiBufferRow(insert_point.row);
|
||||
|
||||
editor.insert(&formatted_text, window, cx);
|
||||
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let anchor_before = snapshot.anchor_after(insert_point);
|
||||
let anchor_after = editor
|
||||
.selections
|
||||
.newest_anchor()
|
||||
.head()
|
||||
.bias_left(&snapshot);
|
||||
|
||||
editor.insert("\n", window, cx);
|
||||
|
||||
let crease_text = acp_thread::selection_name(
|
||||
Some(file_path.as_ref()),
|
||||
&line_range,
|
||||
);
|
||||
|
||||
let fold_placeholder = quote_selection_fold_placeholder(
|
||||
crease_text,
|
||||
weak_editor.clone(),
|
||||
);
|
||||
let crease = Crease::inline(
|
||||
anchor_before..anchor_after,
|
||||
fold_placeholder,
|
||||
render_quote_selection_output_toggle,
|
||||
|_, _, _, _| Empty.into_any(),
|
||||
);
|
||||
editor.insert_creases(vec![crease], cx);
|
||||
editor.fold_at(start_row, window, cx);
|
||||
|
||||
current_offset += selection.len;
|
||||
if !selection.is_entire_line && current_offset < text.len() {
|
||||
current_offset += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cx.stop_propagation();
|
||||
|
||||
let mut images = if let Some(item) = cx.read_from_clipboard() {
|
||||
@@ -2189,7 +2097,8 @@ impl TextThreadEditor {
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
|
||||
let provider_icon = match active_provider {
|
||||
let provider_icon_path = active_provider.as_ref().and_then(|p| p.icon_path());
|
||||
let provider_icon_name = match &active_provider {
|
||||
Some(provider) => provider.icon(),
|
||||
None => IconName::Ai,
|
||||
};
|
||||
@@ -2201,6 +2110,16 @@ impl TextThreadEditor {
|
||||
(Color::Muted, IconName::ChevronDown)
|
||||
};
|
||||
|
||||
let provider_icon_element = if let Some(icon_path) = provider_icon_path {
|
||||
Icon::from_external_svg(icon_path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall)
|
||||
} else {
|
||||
Icon::new(provider_icon_name)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall)
|
||||
};
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.language_model_selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
@@ -2208,7 +2127,7 @@ impl TextThreadEditor {
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
|
||||
.child(provider_icon_element)
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
.color(color)
|
||||
|
||||
@@ -106,6 +106,9 @@ impl Render for AgentNotification {
|
||||
.font(ui_font)
|
||||
.border_color(cx.theme().colors().border)
|
||||
.rounded_xl()
|
||||
.on_click(cx.listener(|_, _, _, cx| {
|
||||
cx.emit(AgentNotificationEvent::Accepted);
|
||||
}))
|
||||
.child(
|
||||
h_flex()
|
||||
.items_start()
|
||||
|
||||
@@ -1,9 +1,25 @@
|
||||
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
|
||||
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use language_model::{LanguageModelProvider, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use ui::{Divider, List, ListBulletItem, prelude::*};
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ProviderIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
impl ProviderIcon {
|
||||
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
|
||||
if let Some(path) = provider.icon_path() {
|
||||
Self::Path(path)
|
||||
} else {
|
||||
Self::Name(provider.icon())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ApiKeysWithProviders {
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
configured_providers: Vec<(ProviderIcon, SharedString)>,
|
||||
}
|
||||
|
||||
impl ApiKeysWithProviders {
|
||||
@@ -26,14 +42,19 @@ impl ApiKeysWithProviders {
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
|
||||
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0))
|
||||
.map(|provider| {
|
||||
(
|
||||
ProviderIcon::from_provider(provider.as_ref()),
|
||||
provider.name().0,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@@ -47,7 +68,14 @@ impl Render for ApiKeysWithProviders {
|
||||
.map(|(icon, name)| {
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
|
||||
.child(match icon {
|
||||
ProviderIcon::Name(icon_name) => Icon::new(icon_name)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
})
|
||||
.child(Label::new(name))
|
||||
});
|
||||
div()
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
|
||||
pub struct AgentPanelOnboarding {
|
||||
user_store: Entity<UserStore>,
|
||||
client: Arc<Client>,
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
has_configured_providers: bool,
|
||||
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ impl AgentPanelOnboarding {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
this.configured_providers = Self::compute_available_providers(cx)
|
||||
this.has_configured_providers = Self::has_configured_providers(cx)
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
@@ -38,20 +38,16 @@ impl AgentPanelOnboarding {
|
||||
Self {
|
||||
user_store,
|
||||
client,
|
||||
configured_providers: Self::compute_available_providers(cx),
|
||||
has_configured_providers: Self::has_configured_providers(cx),
|
||||
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
|
||||
fn has_configured_providers(cx: &App) -> bool {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0))
|
||||
.collect()
|
||||
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +77,7 @@ impl Render for AgentPanelOnboarding {
|
||||
}),
|
||||
)
|
||||
.map(|this| {
|
||||
if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
|
||||
if enrolled_in_trial || is_pro_user || self.has_configured_providers {
|
||||
this
|
||||
} else {
|
||||
this.child(ApiKeysWithoutProviders::new())
|
||||
|
||||
@@ -8,12 +8,10 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
|
||||
use http_client::http::{self, HeaderMap, HeaderValue};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
|
||||
pub use settings::ModelMode;
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
pub mod batches;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
@@ -467,7 +465,6 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate completion with streaming.
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
@@ -480,101 +477,6 @@ pub async fn stream_completion(
|
||||
.map(|output| output.0)
|
||||
}
|
||||
|
||||
/// Generate completion without streaming.
|
||||
pub async fn non_streaming_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
beta_headers: Option<String>,
|
||||
) -> Result<Response, AnthropicError> {
|
||||
let (mut response, rate_limits) =
|
||||
send_request(client, api_url, api_key, &request, beta_headers).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse)?;
|
||||
|
||||
serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
|
||||
} else {
|
||||
Err(handle_error_response(response, rate_limits).await)
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: impl Serialize,
|
||||
beta_headers: Option<String>,
|
||||
) -> Result<(http::Response<AsyncBody>, RateLimitInfo), AnthropicError> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("X-Api-Key", api_key.trim())
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(beta_headers) = beta_headers {
|
||||
request_builder = request_builder.header("Anthropic-Beta", beta_headers);
|
||||
}
|
||||
|
||||
let serialized_request =
|
||||
serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
|
||||
let request = request_builder
|
||||
.body(AsyncBody::from(serialized_request))
|
||||
.map_err(AnthropicError::BuildRequestBody)?;
|
||||
|
||||
let response = client
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(AnthropicError::HttpSend)?;
|
||||
|
||||
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
||||
|
||||
Ok((response, rate_limits))
|
||||
}
|
||||
|
||||
async fn handle_error_response(
|
||||
mut response: http::Response<AsyncBody>,
|
||||
rate_limits: RateLimitInfo,
|
||||
) -> AnthropicError {
|
||||
if response.status().as_u16() == 529 {
|
||||
return AnthropicError::ServerOverloaded {
|
||||
retry_after: rate_limits.retry_after,
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(retry_after) = rate_limits.retry_after {
|
||||
return AnthropicError::RateLimit { retry_after };
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
let read_result = response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse);
|
||||
|
||||
if let Err(err) = read_result {
|
||||
return err;
|
||||
}
|
||||
|
||||
match serde_json::from_str::<Event>(&body) {
|
||||
Ok(Event::Error { error }) => AnthropicError::ApiError(error),
|
||||
Ok(_) | Err(_) => AnthropicError::HttpResponseError {
|
||||
status_code: response.status(),
|
||||
message: body,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// An individual rate limit.
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimit {
|
||||
@@ -678,10 +580,30 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
base: request,
|
||||
stream: true,
|
||||
};
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
|
||||
let (response, rate_limits) =
|
||||
send_request(client, api_url, api_key, &request, beta_headers).await?;
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("X-Api-Key", api_key.trim())
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(beta_headers) = beta_headers {
|
||||
request_builder = request_builder.header("Anthropic-Beta", beta_headers);
|
||||
}
|
||||
|
||||
let serialized_request =
|
||||
serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
|
||||
let request = request_builder
|
||||
.body(AsyncBody::from(serialized_request))
|
||||
.map_err(AnthropicError::BuildRequestBody)?;
|
||||
|
||||
let mut response = client
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(AnthropicError::HttpSend)?;
|
||||
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
let stream = reader
|
||||
@@ -700,8 +622,27 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
})
|
||||
.boxed();
|
||||
Ok((stream, Some(rate_limits)))
|
||||
} else if response.status().as_u16() == 529 {
|
||||
Err(AnthropicError::ServerOverloaded {
|
||||
retry_after: rate_limits.retry_after,
|
||||
})
|
||||
} else if let Some(retry_after) = rate_limits.retry_after {
|
||||
Err(AnthropicError::RateLimit { retry_after })
|
||||
} else {
|
||||
Err(handle_error_response(response, rate_limits).await)
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse)?;
|
||||
|
||||
match serde_json::from_str::<Event>(&body) {
|
||||
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
|
||||
Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
|
||||
status_code: response.status(),
|
||||
message: body,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use futures::AsyncReadExt;
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BatchRequest {
|
||||
pub custom_id: String,
|
||||
pub params: Request,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CreateBatchRequest {
|
||||
pub requests: Vec<BatchRequest>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct MessageBatchRequestCounts {
|
||||
pub processing: u64,
|
||||
pub succeeded: u64,
|
||||
pub errored: u64,
|
||||
pub canceled: u64,
|
||||
pub expired: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct MessageBatch {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub batch_type: String,
|
||||
pub processing_status: String,
|
||||
pub request_counts: MessageBatchRequestCounts,
|
||||
pub ended_at: Option<String>,
|
||||
pub created_at: String,
|
||||
pub expires_at: String,
|
||||
pub archived_at: Option<String>,
|
||||
pub cancel_initiated_at: Option<String>,
|
||||
pub results_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum BatchResult {
|
||||
#[serde(rename = "succeeded")]
|
||||
Succeeded { message: Response },
|
||||
#[serde(rename = "errored")]
|
||||
Errored { error: ApiError },
|
||||
#[serde(rename = "canceled")]
|
||||
Canceled,
|
||||
#[serde(rename = "expired")]
|
||||
Expired,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BatchIndividualResponse {
|
||||
pub custom_id: String,
|
||||
pub result: BatchResult,
|
||||
}
|
||||
|
||||
pub async fn create_batch(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: CreateBatchRequest,
|
||||
) -> Result<MessageBatch, AnthropicError> {
|
||||
let uri = format!("{api_url}/v1/messages/batches");
|
||||
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("X-Api-Key", api_key.trim())
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
let serialized_request =
|
||||
serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
|
||||
let http_request = request_builder
|
||||
.body(AsyncBody::from(serialized_request))
|
||||
.map_err(AnthropicError::BuildRequestBody)?;
|
||||
|
||||
let mut response = client
|
||||
.send(http_request)
|
||||
.await
|
||||
.map_err(AnthropicError::HttpSend)?;
|
||||
|
||||
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse)?;
|
||||
|
||||
serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
|
||||
} else {
|
||||
Err(crate::handle_error_response(response, rate_limits).await)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn retrieve_batch(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
message_batch_id: &str,
|
||||
) -> Result<MessageBatch, AnthropicError> {
|
||||
let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
|
||||
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("X-Api-Key", api_key.trim());
|
||||
|
||||
let http_request = request_builder
|
||||
.body(AsyncBody::default())
|
||||
.map_err(AnthropicError::BuildRequestBody)?;
|
||||
|
||||
let mut response = client
|
||||
.send(http_request)
|
||||
.await
|
||||
.map_err(AnthropicError::HttpSend)?;
|
||||
|
||||
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse)?;
|
||||
|
||||
serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
|
||||
} else {
|
||||
Err(crate::handle_error_response(response, rate_limits).await)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn retrieve_batch_results(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
message_batch_id: &str,
|
||||
) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
|
||||
let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
|
||||
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("X-Api-Key", api_key.trim());
|
||||
|
||||
let http_request = request_builder
|
||||
.body(AsyncBody::default())
|
||||
.map_err(AnthropicError::BuildRequestBody)?;
|
||||
|
||||
let mut response = client
|
||||
.send(http_request)
|
||||
.await
|
||||
.map_err(AnthropicError::HttpSend)?;
|
||||
|
||||
let rate_limits = RateLimitInfo::from_headers(response.headers());
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse)?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for line in body.lines() {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let result: BatchIndividualResponse =
|
||||
serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
} else {
|
||||
Err(crate::handle_error_response(response, rate_limits).await)
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,7 @@ use fs::{Fs, RenameOptions};
|
||||
use futures::{FutureExt, StreamExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
|
||||
Task, WeakEntity,
|
||||
Task,
|
||||
};
|
||||
use itertools::Itertools as _;
|
||||
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
|
||||
@@ -688,7 +688,7 @@ pub struct TextThread {
|
||||
_subscriptions: Vec<Subscription>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
project: Option<WeakEntity<Project>>,
|
||||
project: Option<Entity<Project>>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
completion_mode: agent_settings::CompletionMode,
|
||||
}
|
||||
@@ -708,7 +708,7 @@ impl EventEmitter<TextThreadEvent> for TextThread {}
|
||||
impl TextThread {
|
||||
pub fn local(
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
project: Option<WeakEntity<Project>>,
|
||||
project: Option<Entity<Project>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
slash_commands: Arc<SlashCommandWorkingSet>,
|
||||
@@ -742,7 +742,7 @@ impl TextThread {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
slash_commands: Arc<SlashCommandWorkingSet>,
|
||||
project: Option<WeakEntity<Project>>,
|
||||
project: Option<Entity<Project>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -873,7 +873,7 @@ impl TextThread {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
slash_commands: Arc<SlashCommandWorkingSet>,
|
||||
project: Option<WeakEntity<Project>>,
|
||||
project: Option<Entity<Project>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -1167,6 +1167,10 @@ impl TextThread {
|
||||
self.language_registry.clone()
|
||||
}
|
||||
|
||||
pub fn project(&self) -> Option<Entity<Project>> {
|
||||
self.project.clone()
|
||||
}
|
||||
|
||||
pub fn prompt_builder(&self) -> Arc<PromptBuilder> {
|
||||
self.prompt_builder.clone()
|
||||
}
|
||||
@@ -2963,7 +2967,7 @@ impl TextThread {
|
||||
}
|
||||
|
||||
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) {
|
||||
let Some(project) = self.project.as_ref().and_then(|project| project.upgrade()) else {
|
||||
let Some(project) = &self.project else {
|
||||
return;
|
||||
};
|
||||
project.read(cx).user_store().update(cx, |user_store, cx| {
|
||||
|
||||
@@ -51,7 +51,7 @@ pub struct TextThreadStore {
|
||||
telemetry: Arc<Telemetry>,
|
||||
_watch_updates: Task<Option<()>>,
|
||||
client: Arc<Client>,
|
||||
project: WeakEntity<Project>,
|
||||
project: Entity<Project>,
|
||||
project_is_shared: bool,
|
||||
client_subscription: Option<client::Subscription>,
|
||||
_project_subscriptions: Vec<gpui::Subscription>,
|
||||
@@ -119,10 +119,10 @@ impl TextThreadStore {
|
||||
],
|
||||
project_is_shared: false,
|
||||
client: project.read(cx).client(),
|
||||
project: project.downgrade(),
|
||||
project: project.clone(),
|
||||
prompt_builder,
|
||||
};
|
||||
this.handle_project_shared(cx);
|
||||
this.handle_project_shared(project.clone(), cx);
|
||||
this.synchronize_contexts(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
@@ -146,7 +146,7 @@ impl TextThreadStore {
|
||||
telemetry: project.read(cx).client().telemetry().clone(),
|
||||
_watch_updates: Task::ready(None),
|
||||
client: project.read(cx).client(),
|
||||
project: project.downgrade(),
|
||||
project,
|
||||
project_is_shared: false,
|
||||
client_subscription: None,
|
||||
_project_subscriptions: Default::default(),
|
||||
@@ -180,10 +180,8 @@ impl TextThreadStore {
|
||||
) -> Result<proto::OpenContextResponse> {
|
||||
let context_id = TextThreadId::from_proto(envelope.payload.context_id);
|
||||
let operations = this.update(&mut cx, |this, cx| {
|
||||
let project = this.project.upgrade().context("project not found")?;
|
||||
|
||||
anyhow::ensure!(
|
||||
!project.read(cx).is_via_collab(),
|
||||
!this.project.read(cx).is_via_collab(),
|
||||
"only the host contexts can be opened"
|
||||
);
|
||||
|
||||
@@ -213,9 +211,8 @@ impl TextThreadStore {
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<proto::CreateContextResponse> {
|
||||
let (context_id, operations) = this.update(&mut cx, |this, cx| {
|
||||
let project = this.project.upgrade().context("project not found")?;
|
||||
anyhow::ensure!(
|
||||
!project.read(cx).is_via_collab(),
|
||||
!this.project.read(cx).is_via_collab(),
|
||||
"can only create contexts as the host"
|
||||
);
|
||||
|
||||
@@ -258,9 +255,8 @@ impl TextThreadStore {
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<proto::SynchronizeContextsResponse> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let project = this.project.upgrade().context("project not found")?;
|
||||
anyhow::ensure!(
|
||||
!project.read(cx).is_via_collab(),
|
||||
!this.project.read(cx).is_via_collab(),
|
||||
"only the host can synchronize contexts"
|
||||
);
|
||||
|
||||
@@ -297,12 +293,8 @@ impl TextThreadStore {
|
||||
})?
|
||||
}
|
||||
|
||||
fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let is_shared = project.read(cx).is_shared();
|
||||
fn handle_project_shared(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
|
||||
let is_shared = self.project.read(cx).is_shared();
|
||||
let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
|
||||
if is_shared == was_shared {
|
||||
return;
|
||||
@@ -317,7 +309,7 @@ impl TextThreadStore {
|
||||
false
|
||||
}
|
||||
});
|
||||
let remote_id = project.read(cx).remote_id().unwrap();
|
||||
let remote_id = self.project.read(cx).remote_id().unwrap();
|
||||
self.client_subscription = self
|
||||
.client
|
||||
.subscribe_to_entity(remote_id)
|
||||
@@ -331,13 +323,13 @@ impl TextThreadStore {
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
project: Entity<Project>,
|
||||
event: &project::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::Event::RemoteIdChanged(_) => {
|
||||
self.handle_project_shared(cx);
|
||||
self.handle_project_shared(project, cx);
|
||||
}
|
||||
project::Event::Reshared => {
|
||||
self.advertise_contexts(cx);
|
||||
@@ -390,10 +382,7 @@ impl TextThreadStore {
|
||||
}
|
||||
|
||||
pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return Task::ready(Err(anyhow::anyhow!("project was dropped")));
|
||||
};
|
||||
let project = project.read(cx);
|
||||
let project = self.project.read(cx);
|
||||
let Some(project_id) = project.remote_id() else {
|
||||
return Task::ready(Err(anyhow::anyhow!("project was not remote")));
|
||||
};
|
||||
@@ -552,10 +541,7 @@ impl TextThreadStore {
|
||||
text_thread_id: TextThreadId,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Entity<TextThread>>> {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return Task::ready(Err(anyhow::anyhow!("project was dropped")));
|
||||
};
|
||||
let project = project.read(cx);
|
||||
let project = self.project.read(cx);
|
||||
let Some(project_id) = project.remote_id() else {
|
||||
return Task::ready(Err(anyhow::anyhow!("project was not remote")));
|
||||
};
|
||||
@@ -632,10 +618,7 @@ impl TextThreadStore {
|
||||
event: &TextThreadEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let Some(project_id) = project.read(cx).remote_id() else {
|
||||
let Some(project_id) = self.project.read(cx).remote_id() else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -669,14 +652,12 @@ impl TextThreadStore {
|
||||
}
|
||||
|
||||
fn advertise_contexts(&self, cx: &App) {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let Some(project_id) = project.read(cx).remote_id() else {
|
||||
let Some(project_id) = self.project.read(cx).remote_id() else {
|
||||
return;
|
||||
};
|
||||
|
||||
// For now, only the host can advertise their open contexts.
|
||||
if project.read(cx).is_via_collab() {
|
||||
if self.project.read(cx).is_via_collab() {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -708,10 +689,7 @@ impl TextThreadStore {
|
||||
}
|
||||
|
||||
fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let Some(project_id) = project.read(cx).remote_id() else {
|
||||
let Some(project_id) = self.project.read(cx).remote_id() else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -850,10 +828,7 @@ impl TextThreadStore {
|
||||
}
|
||||
|
||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||
let Some(project) = self.project.upgrade() else {
|
||||
return;
|
||||
};
|
||||
let context_server_store = project.read(cx).context_server_store();
|
||||
let context_server_store = self.project.read(cx).context_server_store();
|
||||
cx.subscribe(&context_server_store, Self::handle_context_server_event)
|
||||
.detach();
|
||||
|
||||
|
||||
@@ -31,10 +31,18 @@ pub struct PredictEditsRequest {
|
||||
/// Within `signatures`
|
||||
pub excerpt_parent: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub related_files: Vec<RelatedFile>,
|
||||
pub included_files: Vec<IncludedFile>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub signatures: Vec<Signature>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub referenced_declarations: Vec<ReferencedDeclaration>,
|
||||
pub events: Vec<Arc<Event>>,
|
||||
#[serde(default)]
|
||||
pub can_collect_data: bool,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub diagnostic_groups: Vec<DiagnosticGroup>,
|
||||
#[serde(skip_serializing_if = "is_default", default)]
|
||||
pub diagnostic_groups_truncated: bool,
|
||||
/// Info about the git repository state, only present when can_collect_data is true.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub git_info: Option<PredictEditsGitInfo>,
|
||||
@@ -50,7 +58,7 @@ pub struct PredictEditsRequest {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RelatedFile {
|
||||
pub struct IncludedFile {
|
||||
pub path: Arc<Path>,
|
||||
pub max_row: Line,
|
||||
pub excerpts: Vec<Excerpt>,
|
||||
@@ -64,9 +72,11 @@ pub struct Excerpt {
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum PromptFormat {
|
||||
/// XML old_tex/new_text
|
||||
MarkedExcerpt,
|
||||
LabeledSections,
|
||||
NumLinesUniDiff,
|
||||
OldTextNewText,
|
||||
/// Prompt format intended for use via edit_prediction_cli
|
||||
/// Prompt format intended for use via zeta_cli
|
||||
OnlySnippets,
|
||||
/// One-sentence instructions used in fine-tuned models
|
||||
Minimal,
|
||||
@@ -77,7 +87,7 @@ pub enum PromptFormat {
|
||||
}
|
||||
|
||||
impl PromptFormat {
|
||||
pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
|
||||
pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
|
||||
}
|
||||
|
||||
impl Default for PromptFormat {
|
||||
@@ -95,7 +105,10 @@ impl PromptFormat {
|
||||
impl std::fmt::Display for PromptFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
|
||||
PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
|
||||
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
|
||||
PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
|
||||
PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
|
||||
PromptFormat::Minimal => write!(f, "Minimal"),
|
||||
PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
|
||||
@@ -165,6 +178,67 @@ impl<'a> std::fmt::Display for DiffPathFmt<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Signature {
|
||||
pub text: String,
|
||||
pub text_is_truncated: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub parent_index: Option<usize>,
|
||||
/// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
|
||||
/// file is implicitly the file that contains the descendant declaration or excerpt.
|
||||
pub range: Range<Line>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReferencedDeclaration {
|
||||
pub path: Arc<Path>,
|
||||
pub text: String,
|
||||
pub text_is_truncated: bool,
|
||||
/// Range of `text` within file, possibly truncated according to `text_is_truncated`
|
||||
pub range: Range<Line>,
|
||||
/// Range within `text`
|
||||
pub signature_range: Range<usize>,
|
||||
/// Index within `signatures`.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub parent_index: Option<usize>,
|
||||
pub score_components: DeclarationScoreComponents,
|
||||
pub signature_score: f32,
|
||||
pub declaration_score: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeclarationScoreComponents {
|
||||
pub is_same_file: bool,
|
||||
pub is_referenced_nearby: bool,
|
||||
pub is_referenced_in_breadcrumb: bool,
|
||||
pub reference_count: usize,
|
||||
pub same_file_declaration_count: usize,
|
||||
pub declaration_count: usize,
|
||||
pub reference_line_distance: u32,
|
||||
pub declaration_line_distance: u32,
|
||||
pub excerpt_vs_item_jaccard: f32,
|
||||
pub excerpt_vs_signature_jaccard: f32,
|
||||
pub adjacent_vs_item_jaccard: f32,
|
||||
pub adjacent_vs_signature_jaccard: f32,
|
||||
pub excerpt_vs_item_weighted_overlap: f32,
|
||||
pub excerpt_vs_signature_weighted_overlap: f32,
|
||||
pub adjacent_vs_item_weighted_overlap: f32,
|
||||
pub adjacent_vs_signature_weighted_overlap: f32,
|
||||
pub path_import_match_count: usize,
|
||||
pub wildcard_path_import_match_count: usize,
|
||||
pub import_similarity: f32,
|
||||
pub max_import_similarity: f32,
|
||||
pub normalized_import_similarity: f32,
|
||||
pub wildcard_import_similarity: f32,
|
||||
pub normalized_wildcard_import_similarity: f32,
|
||||
pub included_by_others: usize,
|
||||
pub includes_others: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsResponse {
|
||||
pub request_id: Uuid,
|
||||
@@ -188,6 +262,10 @@ pub struct Edit {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
|
||||
*value == T::default()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
|
||||
pub struct Point {
|
||||
pub line: Line,
|
||||
|
||||
@@ -15,4 +15,9 @@ path = "src/cloud_zeta2_prompt.rs"
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
indoc.workspace = true
|
||||
ordered-float.workspace = true
|
||||
rustc-hash.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
use anyhow::Result;
|
||||
//! Zeta2 prompt planning and generation code shared with cloud.
|
||||
pub mod retrieval_prompt;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
|
||||
self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat,
|
||||
ReferencedDeclaration,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use ordered_float::OrderedFloat;
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use serde::Serialize;
|
||||
use std::cmp;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
|
||||
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
|
||||
|
||||
@@ -16,6 +24,69 @@ pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_s
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
|
||||
|
||||
// TODO: use constants for markers?
|
||||
const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
|
||||
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
|
||||
|
||||
The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region.
|
||||
|
||||
Other code is provided for context, and `…` indicates when code has been skipped.
|
||||
|
||||
## Edit History
|
||||
|
||||
"};
|
||||
|
||||
const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.
|
||||
|
||||
Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).
|
||||
|
||||
The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.
|
||||
|
||||
Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:
|
||||
|
||||
<|current_section|>
|
||||
for i in 0..16 {
|
||||
println!("{i}");
|
||||
}
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
|
||||
# Instructions
|
||||
|
||||
You are an edit prediction agent in a code editor.
|
||||
Your job is to predict the next edit that the user will make,
|
||||
based on their last few edits and their current cursor location.
|
||||
|
||||
## Output Format
|
||||
|
||||
You must briefly explain your understanding of the user's goal, in one
|
||||
or two sentences, and then specify their next edit in the form of a
|
||||
unified diff, like this:
|
||||
|
||||
```
|
||||
--- a/src/myapp/cli.py
|
||||
+++ b/src/myapp/cli.py
|
||||
@@ ... @@
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
+from constants import LOG_LEVEL_WARNING
|
||||
@@ ... @@
|
||||
config.headless()
|
||||
config.set_interactive(false)
|
||||
-config.set_log_level(LOG_L)
|
||||
+config.set_log_level(LOG_LEVEL_WARNING)
|
||||
config.set_use_color(True)
|
||||
```
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
|
||||
|
||||
@@ -23,6 +94,20 @@ const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
|
||||
"#};
|
||||
|
||||
const UNIFIED_DIFF_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
Analyze the edit history and the files, then provide the unified diff for your predicted edits.
|
||||
Do not include the cursor marker in your output.
|
||||
Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`).
|
||||
Do not include line numbers in the hunk headers, use `@@ ... @@`.
|
||||
Removed lines begin with `-`.
|
||||
Added lines begin with `+`.
|
||||
Context lines begin with an extra space.
|
||||
Context and removed lines are used to match the target edit location, so make sure to include enough of them
|
||||
to uniquely identify it amongst all excerpts of code provided.
|
||||
"};
|
||||
|
||||
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
@@ -79,25 +164,49 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
|
||||
Remember that the edits in the edit history have already been applied.
|
||||
"#};
|
||||
|
||||
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
|
||||
pub fn build_prompt(
|
||||
request: &predict_edits_v3::PredictEditsRequest,
|
||||
) -> Result<(String, SectionLabels)> {
|
||||
let mut section_labels = Default::default();
|
||||
|
||||
let prompt_data = PromptData {
|
||||
events: request.events.clone(),
|
||||
cursor_point: request.cursor_point,
|
||||
cursor_path: request.excerpt_path.clone(),
|
||||
included_files: request.related_files.clone(),
|
||||
included_files: request.included_files.clone(),
|
||||
};
|
||||
match request.prompt_format {
|
||||
PromptFormat::MinimalQwen => {
|
||||
return Ok(MinimalQwenPrompt.render(&prompt_data));
|
||||
return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels));
|
||||
}
|
||||
PromptFormat::SeedCoder1120 => {
|
||||
return Ok(SeedCoder1120Prompt.render(&prompt_data));
|
||||
return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels));
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let insertions = match request.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
|
||||
let mut insertions = match request.prompt_format {
|
||||
PromptFormat::MarkedExcerpt => vec![
|
||||
(
|
||||
Point {
|
||||
line: request.excerpt_line_range.start,
|
||||
column: 0,
|
||||
},
|
||||
EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
|
||||
),
|
||||
(request.cursor_point, CURSOR_MARKER),
|
||||
(
|
||||
Point {
|
||||
line: request.excerpt_line_range.end,
|
||||
column: 0,
|
||||
},
|
||||
EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
|
||||
),
|
||||
],
|
||||
PromptFormat::LabeledSections
|
||||
| PromptFormat::NumLinesUniDiff
|
||||
| PromptFormat::Minimal
|
||||
| PromptFormat::OldTextNewText => {
|
||||
vec![(request.cursor_point, CURSOR_MARKER)]
|
||||
}
|
||||
PromptFormat::OnlySnippets => vec![],
|
||||
@@ -106,6 +215,9 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
|
||||
};
|
||||
|
||||
let mut prompt = match request.prompt_format {
|
||||
PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OnlySnippets => String::new(),
|
||||
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
|
||||
@@ -135,7 +247,7 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
|
||||
You can only edit exactly this part of the file.
|
||||
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
|
||||
"},
|
||||
PromptFormat::OldTextNewText => indoc! {"
|
||||
PromptFormat::NumLinesUniDiff | PromptFormat::OldTextNewText => indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
Here is some excerpts of code that you should take into account to predict the next edit.
|
||||
@@ -151,51 +263,64 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
|
||||
|
||||
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
|
||||
"},
|
||||
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
indoc! {"
|
||||
_ => indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history have been applied.
|
||||
"}
|
||||
}
|
||||
"},
|
||||
};
|
||||
|
||||
prompt.push_str(excerpts_preamble);
|
||||
prompt.push('\n');
|
||||
|
||||
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
|
||||
for related_file in &request.related_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
|
||||
let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?;
|
||||
section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?;
|
||||
} else {
|
||||
if request.prompt_format == PromptFormat::LabeledSections {
|
||||
anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm");
|
||||
}
|
||||
|
||||
let include_line_numbers = matches!(
|
||||
request.prompt_format,
|
||||
PromptFormat::NumLinesUniDiff | PromptFormat::Minimal
|
||||
);
|
||||
for related_file in &request.included_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match request.prompt_format {
|
||||
PromptFormat::NumLinesUniDiff => {
|
||||
prompt.push_str(UNIFIED_DIFF_REMINDER);
|
||||
}
|
||||
PromptFormat::OldTextNewText => {
|
||||
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
|
||||
}
|
||||
@@ -205,7 +330,7 @@ pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<S
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
Ok((prompt, section_labels))
|
||||
}
|
||||
|
||||
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
|
||||
@@ -319,11 +444,476 @@ pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>])
|
||||
writeln!(output, "`````\n").unwrap();
|
||||
}
|
||||
|
||||
pub struct SyntaxBasedPrompt<'a> {
|
||||
request: &'a predict_edits_v3::PredictEditsRequest,
|
||||
/// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
|
||||
/// `to_prompt_string`.
|
||||
snippets: Vec<PlannedSnippet<'a>>,
|
||||
budget_used: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PlannedSnippet<'a> {
|
||||
path: Arc<Path>,
|
||||
range: Range<Line>,
|
||||
text: &'a str,
|
||||
// TODO: Indicate this in the output
|
||||
#[allow(dead_code)]
|
||||
text_is_truncated: bool,
|
||||
}
|
||||
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
|
||||
pub enum DeclarationStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize)]
|
||||
pub struct SectionLabels {
|
||||
pub excerpt_index: usize,
|
||||
pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
|
||||
}
|
||||
|
||||
impl<'a> SyntaxBasedPrompt<'a> {
|
||||
/// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
|
||||
///
|
||||
/// Initializes a priority queue by populating it with each snippet, finding the
|
||||
/// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
|
||||
/// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
|
||||
/// the cost of upgrade.
|
||||
///
|
||||
/// TODO: Implement an early halting condition. One option might be to have another priority
|
||||
/// queue where the score is the size, and update it accordingly. Another option might be to
|
||||
/// have some simpler heuristic like bailing after N failed insertions, or based on how much
|
||||
/// budget is left.
|
||||
///
|
||||
/// TODO: Has the current known sources of imprecision:
|
||||
///
|
||||
/// * Does not consider snippet overlap when ranking. For example, it might add a field to the
|
||||
/// plan even though the containing struct is already included.
|
||||
///
|
||||
/// * Does not consider cost of signatures when ranking snippets - this is tricky since
|
||||
/// signatures may be shared by multiple snippets.
|
||||
///
|
||||
/// * Does not include file paths / other text when considering max_bytes.
|
||||
pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
|
||||
let mut this = Self {
|
||||
request,
|
||||
snippets: Vec::new(),
|
||||
budget_used: request.excerpt.len(),
|
||||
};
|
||||
let mut included_parents = FxHashSet::default();
|
||||
let additional_parents = this.additional_parent_signatures(
|
||||
&request.excerpt_path,
|
||||
request.excerpt_parent,
|
||||
&included_parents,
|
||||
)?;
|
||||
this.add_parents(&mut included_parents, additional_parents);
|
||||
|
||||
let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
|
||||
|
||||
if this.budget_used > max_bytes {
|
||||
return Err(anyhow!(
|
||||
"Excerpt + signatures size of {} already exceeds budget of {}",
|
||||
this.budget_used,
|
||||
max_bytes
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct QueueEntry {
|
||||
score_density: OrderedFloat<f32>,
|
||||
declaration_index: usize,
|
||||
style: DeclarationStyle,
|
||||
}
|
||||
|
||||
// Initialize priority queue with the best score for each snippet.
|
||||
let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
|
||||
for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
|
||||
let (style, score_density) = DeclarationStyle::iter()
|
||||
.map(|style| {
|
||||
(
|
||||
style,
|
||||
OrderedFloat(declaration_score_density(&declaration, style)),
|
||||
)
|
||||
})
|
||||
.max_by_key(|(_, score_density)| *score_density)
|
||||
.unwrap();
|
||||
queue.push(QueueEntry {
|
||||
score_density,
|
||||
declaration_index,
|
||||
style,
|
||||
});
|
||||
}
|
||||
|
||||
// Knapsack selection loop
|
||||
while let Some(queue_entry) = queue.pop() {
|
||||
let Some(declaration) = request
|
||||
.referenced_declarations
|
||||
.get(queue_entry.declaration_index)
|
||||
else {
|
||||
return Err(anyhow!(
|
||||
"Invalid declaration index {}",
|
||||
queue_entry.declaration_index
|
||||
));
|
||||
};
|
||||
|
||||
let mut additional_bytes = declaration_size(declaration, queue_entry.style);
|
||||
if this.budget_used + additional_bytes > max_bytes {
|
||||
continue;
|
||||
}
|
||||
|
||||
let additional_parents = this.additional_parent_signatures(
|
||||
&declaration.path,
|
||||
declaration.parent_index,
|
||||
&mut included_parents,
|
||||
)?;
|
||||
additional_bytes += additional_parents
|
||||
.iter()
|
||||
.map(|(_, snippet)| snippet.text.len())
|
||||
.sum::<usize>();
|
||||
if this.budget_used + additional_bytes > max_bytes {
|
||||
continue;
|
||||
}
|
||||
|
||||
this.budget_used += additional_bytes;
|
||||
this.add_parents(&mut included_parents, additional_parents);
|
||||
let planned_snippet = match queue_entry.style {
|
||||
DeclarationStyle::Signature => {
|
||||
let Some(text) = declaration.text.get(declaration.signature_range.clone())
|
||||
else {
|
||||
return Err(anyhow!(
|
||||
"Invalid declaration signature_range {:?} with text.len() = {}",
|
||||
declaration.signature_range,
|
||||
declaration.text.len()
|
||||
));
|
||||
};
|
||||
let signature_start_line = declaration.range.start
|
||||
+ Line(
|
||||
declaration.text[..declaration.signature_range.start]
|
||||
.lines()
|
||||
.count() as u32,
|
||||
);
|
||||
let signature_end_line = signature_start_line
|
||||
+ Line(
|
||||
declaration.text
|
||||
[declaration.signature_range.start..declaration.signature_range.end]
|
||||
.lines()
|
||||
.count() as u32,
|
||||
);
|
||||
let range = signature_start_line..signature_end_line;
|
||||
|
||||
PlannedSnippet {
|
||||
path: declaration.path.clone(),
|
||||
range,
|
||||
text,
|
||||
text_is_truncated: declaration.text_is_truncated,
|
||||
}
|
||||
}
|
||||
DeclarationStyle::Declaration => PlannedSnippet {
|
||||
path: declaration.path.clone(),
|
||||
range: declaration.range.clone(),
|
||||
text: &declaration.text,
|
||||
text_is_truncated: declaration.text_is_truncated,
|
||||
},
|
||||
};
|
||||
this.snippets.push(planned_snippet);
|
||||
|
||||
// When a Signature is consumed, insert an entry for Definition style.
|
||||
if queue_entry.style == DeclarationStyle::Signature {
|
||||
let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
|
||||
let declaration_size =
|
||||
declaration_size(&declaration, DeclarationStyle::Declaration);
|
||||
let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
|
||||
let declaration_score =
|
||||
declaration_score(&declaration, DeclarationStyle::Declaration);
|
||||
|
||||
let score_diff = declaration_score - signature_score;
|
||||
let size_diff = declaration_size.saturating_sub(signature_size);
|
||||
if score_diff > 0.0001 && size_diff > 0 {
|
||||
queue.push(QueueEntry {
|
||||
declaration_index: queue_entry.declaration_index,
|
||||
score_density: OrderedFloat(score_diff / (size_diff as f32)),
|
||||
style: DeclarationStyle::Declaration,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(this)
|
||||
}
|
||||
|
||||
fn add_parents(
|
||||
&mut self,
|
||||
included_parents: &mut FxHashSet<usize>,
|
||||
snippets: Vec<(usize, PlannedSnippet<'a>)>,
|
||||
) {
|
||||
for (parent_index, snippet) in snippets {
|
||||
included_parents.insert(parent_index);
|
||||
self.budget_used += snippet.text.len();
|
||||
self.snippets.push(snippet);
|
||||
}
|
||||
}
|
||||
|
||||
fn additional_parent_signatures(
|
||||
&self,
|
||||
path: &Arc<Path>,
|
||||
parent_index: Option<usize>,
|
||||
included_parents: &FxHashSet<usize>,
|
||||
) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
|
||||
let mut results = Vec::new();
|
||||
self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn additional_parent_signatures_impl(
|
||||
&self,
|
||||
path: &Arc<Path>,
|
||||
parent_index: Option<usize>,
|
||||
included_parents: &FxHashSet<usize>,
|
||||
results: &mut Vec<(usize, PlannedSnippet<'a>)>,
|
||||
) -> Result<()> {
|
||||
let Some(parent_index) = parent_index else {
|
||||
return Ok(());
|
||||
};
|
||||
if included_parents.contains(&parent_index) {
|
||||
return Ok(());
|
||||
}
|
||||
let Some(parent_signature) = self.request.signatures.get(parent_index) else {
|
||||
return Err(anyhow!("Invalid parent index {}", parent_index));
|
||||
};
|
||||
results.push((
|
||||
parent_index,
|
||||
PlannedSnippet {
|
||||
path: path.clone(),
|
||||
range: parent_signature.range.clone(),
|
||||
text: &parent_signature.text,
|
||||
text_is_truncated: parent_signature.text_is_truncated,
|
||||
},
|
||||
));
|
||||
self.additional_parent_signatures_impl(
|
||||
path,
|
||||
parent_signature.parent_index,
|
||||
included_parents,
|
||||
results,
|
||||
)
|
||||
}
|
||||
|
||||
/// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
|
||||
/// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
|
||||
/// chunks.
|
||||
pub fn write(
|
||||
&'a self,
|
||||
excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
|
||||
prompt: &mut String,
|
||||
) -> Result<SectionLabels> {
|
||||
let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
|
||||
FxHashMap::default();
|
||||
for snippet in &self.snippets {
|
||||
file_to_snippets
|
||||
.entry(&snippet.path)
|
||||
.or_default()
|
||||
.push(snippet);
|
||||
}
|
||||
|
||||
// Reorder so that file with cursor comes last
|
||||
let mut file_snippets = Vec::new();
|
||||
let mut excerpt_file_snippets = Vec::new();
|
||||
for (file_path, snippets) in file_to_snippets {
|
||||
if file_path == self.request.excerpt_path.as_ref() {
|
||||
excerpt_file_snippets = snippets;
|
||||
} else {
|
||||
file_snippets.push((file_path, snippets, false));
|
||||
}
|
||||
}
|
||||
let excerpt_snippet = PlannedSnippet {
|
||||
path: self.request.excerpt_path.clone(),
|
||||
range: self.request.excerpt_line_range.clone(),
|
||||
text: &self.request.excerpt,
|
||||
text_is_truncated: false,
|
||||
};
|
||||
excerpt_file_snippets.push(&excerpt_snippet);
|
||||
file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
|
||||
|
||||
let section_labels =
|
||||
self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?;
|
||||
|
||||
Ok(section_labels)
|
||||
}
|
||||
|
||||
fn push_file_snippets(
|
||||
&self,
|
||||
output: &mut String,
|
||||
excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
|
||||
file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
|
||||
) -> Result<SectionLabels> {
|
||||
let mut section_ranges = Vec::new();
|
||||
let mut excerpt_index = None;
|
||||
|
||||
for (file_path, mut snippets, is_excerpt_file) in file_snippets {
|
||||
snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
|
||||
|
||||
// TODO: What if the snippets get expanded too large to be editable?
|
||||
let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
|
||||
let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
|
||||
for snippet in snippets {
|
||||
if let Some((_, current_snippet_range)) = current_snippet.as_mut()
|
||||
&& snippet.range.start <= current_snippet_range.end
|
||||
{
|
||||
current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
|
||||
continue;
|
||||
}
|
||||
if let Some(current_snippet) = current_snippet.take() {
|
||||
disjoint_snippets.push(current_snippet);
|
||||
}
|
||||
current_snippet = Some((snippet, snippet.range.clone()));
|
||||
}
|
||||
if let Some(current_snippet) = current_snippet.take() {
|
||||
disjoint_snippets.push(current_snippet);
|
||||
}
|
||||
|
||||
writeln!(output, "`````path={}", file_path.display()).ok();
|
||||
let mut skipped_last_snippet = false;
|
||||
for (snippet, range) in disjoint_snippets {
|
||||
let section_index = section_ranges.len();
|
||||
|
||||
match self.request.prompt_format {
|
||||
PromptFormat::MarkedExcerpt
|
||||
| PromptFormat::OnlySnippets
|
||||
| PromptFormat::OldTextNewText
|
||||
| PromptFormat::Minimal
|
||||
| PromptFormat::NumLinesUniDiff => {
|
||||
if range.start.0 > 0 && !skipped_last_snippet {
|
||||
output.push_str("…\n");
|
||||
}
|
||||
}
|
||||
PromptFormat::LabeledSections => {
|
||||
if is_excerpt_file
|
||||
&& range.start <= self.request.excerpt_line_range.start
|
||||
&& range.end >= self.request.excerpt_line_range.end
|
||||
{
|
||||
writeln!(output, "<|current_section|>").ok();
|
||||
} else {
|
||||
writeln!(output, "<|section_{}|>", section_index).ok();
|
||||
}
|
||||
}
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
}
|
||||
|
||||
let push_full_snippet = |output: &mut String| {
|
||||
if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
|
||||
for (i, line) in snippet.text.lines().enumerate() {
|
||||
writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
|
||||
}
|
||||
} else {
|
||||
output.push_str(&snippet.text);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
if is_excerpt_file {
|
||||
if self.request.prompt_format == PromptFormat::OnlySnippets {
|
||||
if range.start >= self.request.excerpt_line_range.start
|
||||
&& range.end <= self.request.excerpt_line_range.end
|
||||
{
|
||||
skipped_last_snippet = true;
|
||||
} else {
|
||||
skipped_last_snippet = false;
|
||||
output.push_str(snippet.text);
|
||||
}
|
||||
} else if !excerpt_file_insertions.is_empty() {
|
||||
let lines = snippet.text.lines().collect::<Vec<_>>();
|
||||
let push_line = |output: &mut String, line_ix: usize| {
|
||||
if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
|
||||
write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
|
||||
}
|
||||
anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
|
||||
};
|
||||
let mut last_line_ix = 0;
|
||||
let mut insertion_ix = 0;
|
||||
while insertion_ix < excerpt_file_insertions.len() {
|
||||
let (point, insertion) = &excerpt_file_insertions[insertion_ix];
|
||||
let found = point.line >= range.start && point.line <= range.end;
|
||||
if found {
|
||||
excerpt_index = Some(section_index);
|
||||
let insertion_line_ix = (point.line.0 - range.start.0) as usize;
|
||||
for line_ix in last_line_ix..insertion_line_ix {
|
||||
push_line(output, line_ix)?;
|
||||
}
|
||||
if let Some(next_line) = lines.get(insertion_line_ix) {
|
||||
if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
|
||||
write!(
|
||||
output,
|
||||
"{}|",
|
||||
insertion_line_ix as u32 + range.start.0 + 1
|
||||
)?
|
||||
}
|
||||
output.push_str(&next_line[..point.column as usize]);
|
||||
output.push_str(insertion);
|
||||
writeln!(output, "{}", &next_line[point.column as usize..])?;
|
||||
} else {
|
||||
writeln!(output, "{}", insertion)?;
|
||||
}
|
||||
last_line_ix = insertion_line_ix + 1;
|
||||
excerpt_file_insertions.remove(insertion_ix);
|
||||
continue;
|
||||
}
|
||||
insertion_ix += 1;
|
||||
}
|
||||
skipped_last_snippet = false;
|
||||
for line_ix in last_line_ix..lines.len() {
|
||||
push_line(output, line_ix)?;
|
||||
}
|
||||
} else {
|
||||
skipped_last_snippet = false;
|
||||
push_full_snippet(output)?;
|
||||
}
|
||||
} else {
|
||||
skipped_last_snippet = false;
|
||||
push_full_snippet(output)?;
|
||||
}
|
||||
|
||||
section_ranges.push((snippet.path.clone(), range));
|
||||
}
|
||||
|
||||
output.push_str("`````\n\n");
|
||||
}
|
||||
|
||||
Ok(SectionLabels {
|
||||
// TODO: Clean this up
|
||||
excerpt_index: match self.request.prompt_format {
|
||||
PromptFormat::OnlySnippets => 0,
|
||||
_ => excerpt_index.context("bug: no snippet found for excerpt")?,
|
||||
},
|
||||
section_ranges,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
|
||||
declaration_score(declaration, style) / declaration_size(declaration, style) as f32
|
||||
}
|
||||
|
||||
fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
|
||||
match style {
|
||||
DeclarationStyle::Signature => declaration.signature_score,
|
||||
DeclarationStyle::Declaration => declaration.declaration_score,
|
||||
}
|
||||
}
|
||||
|
||||
fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
|
||||
match style {
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.text.len(),
|
||||
}
|
||||
}
|
||||
|
||||
struct PromptData {
|
||||
events: Vec<Arc<Event>>,
|
||||
cursor_point: Point,
|
||||
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
|
||||
included_files: Vec<RelatedFile>,
|
||||
included_files: Vec<IncludedFile>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -461,7 +1051,7 @@ impl SeedCoder1120Prompt {
|
||||
context
|
||||
}
|
||||
|
||||
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
|
||||
fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String {
|
||||
let mut buf = String::new();
|
||||
const FIM_SUFFIX: &str = "<[fim-suffix]>";
|
||||
const FIM_PREFIX: &str = "<[fim-prefix]>";
|
||||
|
||||
244
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
Normal file
244
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
|
||||
use indoc::indoc;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::{push_events, write_codeblock};
|
||||
|
||||
pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
|
||||
let mut prompt = SEARCH_INSTRUCTIONS.to_string();
|
||||
|
||||
if !request.events.is_empty() {
|
||||
writeln!(&mut prompt, "\n## User Edits\n\n")?;
|
||||
push_events(&mut prompt, &request.events);
|
||||
}
|
||||
|
||||
writeln!(&mut prompt, "## Cursor context\n")?;
|
||||
write_codeblock(
|
||||
&request.excerpt_path,
|
||||
&[Excerpt {
|
||||
start_line: request.excerpt_line_range.start,
|
||||
text: request.excerpt.into(),
|
||||
}],
|
||||
&[],
|
||||
request.cursor_file_max_row,
|
||||
true,
|
||||
&mut prompt,
|
||||
);
|
||||
|
||||
writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Search for relevant code
|
||||
///
|
||||
/// For the best results, run multiple queries at once with a single invocation of this tool.
|
||||
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct SearchToolInput {
|
||||
/// An array of queries to run for gathering context relevant to the next prediction
|
||||
#[schemars(length(max = 3))]
|
||||
#[serde(deserialize_with = "deserialize_queries")]
|
||||
pub queries: Box<[SearchToolQuery]>,
|
||||
}
|
||||
|
||||
fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum QueryCollection {
|
||||
Array(Box<[SearchToolQuery]>),
|
||||
DoubleArray(Box<[Box<[SearchToolQuery]>]>),
|
||||
Single(SearchToolQuery),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum MaybeDoubleEncoded {
|
||||
SingleEncoded(QueryCollection),
|
||||
DoubleEncoded(String),
|
||||
}
|
||||
|
||||
let result = MaybeDoubleEncoded::deserialize(deserializer)?;
|
||||
|
||||
let normalized = match result {
|
||||
MaybeDoubleEncoded::SingleEncoded(value) => value,
|
||||
MaybeDoubleEncoded::DoubleEncoded(value) => {
|
||||
serde_json::from_str(&value).map_err(D::Error::custom)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(match normalized {
|
||||
QueryCollection::Array(items) => items,
|
||||
QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
|
||||
QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Search for relevant code by path, syntax hierarchy, and content.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
|
||||
pub struct SearchToolQuery {
|
||||
/// 1. A glob pattern to match file paths in the codebase to search in.
|
||||
pub glob: String,
|
||||
/// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
|
||||
///
|
||||
/// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
|
||||
///
|
||||
/// Example: Searching for a `User` class
|
||||
/// ["class\s+User"]
|
||||
///
|
||||
/// Example: Searching for a `get_full_name` method under a `User` class
|
||||
/// ["class\s+User", "def\sget_full_name"]
|
||||
///
|
||||
/// Skip this field to match on content alone.
|
||||
#[schemars(length(max = 3))]
|
||||
#[serde(default)]
|
||||
pub syntax_node: Vec<String>,
|
||||
/// 3. An optional regular expression to match the final content that should appear in the results.
|
||||
///
|
||||
/// - Content will be matched within all lines of the matched syntax nodes.
|
||||
/// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
|
||||
/// - If no syntax node regexes are provided, the content will be matched within the entire file.
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
pub const TOOL_NAME: &str = "search";
|
||||
|
||||
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are part of an edit prediction system in a code editor.
|
||||
Your role is to search for code that will serve as context for predicting the next edit.
|
||||
|
||||
- Analyze the user's recent edits and current cursor context
|
||||
- Use the `search` tool to find code that is relevant for predicting the next edit
|
||||
- Focus on finding:
|
||||
- Code patterns that might need similar changes based on the recent edits
|
||||
- Functions, variables, types, and constants referenced in the current cursor context
|
||||
- Related implementations, usages, or dependencies that may require consistent updates
|
||||
- How items defined in the cursor excerpt are used or altered
|
||||
- You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
|
||||
- Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
|
||||
- Avoid using wildcard globs if you already know the file path of the content you're looking for
|
||||
"#};
|
||||
|
||||
const TOOL_USE_REMINDER: &str = indoc! {"
|
||||
--
|
||||
Analyze the user's intent in one to two sentences, then call the `search` tool.
|
||||
"};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_queries() {
|
||||
let single_query_json = indoc! {r#"{
|
||||
"queries": {
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
}
|
||||
}"#};
|
||||
|
||||
let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
|
||||
assert_eq!(flat_input.queries.len(), 1);
|
||||
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
|
||||
|
||||
let flat_json = indoc! {r#"{
|
||||
"queries": [
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
},
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
}"#};
|
||||
|
||||
let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
|
||||
assert_eq!(flat_input.queries.len(), 2);
|
||||
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
|
||||
assert_eq!(flat_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(flat_input.queries[1].content, None);
|
||||
|
||||
let nested_json = indoc! {r#"{
|
||||
"queries": [
|
||||
[
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
]
|
||||
}"#};
|
||||
|
||||
let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
|
||||
|
||||
assert_eq!(nested_input.queries.len(), 2);
|
||||
|
||||
assert_eq!(nested_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
|
||||
assert_eq!(nested_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(nested_input.queries[1].content, None);
|
||||
|
||||
let double_encoded_queries = serde_json::to_string(&json!({
|
||||
"queries": serde_json::to_string(&json!([
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
},
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
])).unwrap()
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let double_encoded_input: SearchToolInput =
|
||||
serde_json::from_str(&double_encoded_queries).unwrap();
|
||||
|
||||
assert_eq!(double_encoded_input.queries.len(), 2);
|
||||
|
||||
assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(
|
||||
double_encoded_input.queries[0].content,
|
||||
Some("assert".to_string())
|
||||
);
|
||||
assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(double_encoded_input.queries[1].content, None);
|
||||
|
||||
// ### ERROR Switching from var declarations to lexical declarations [RUN 073]
|
||||
// invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ path = "src/codestral.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
|
||||
use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::{App, Context, Entity, Task};
|
||||
use http_client::HttpClient;
|
||||
@@ -43,17 +43,17 @@ impl CurrentCompletion {
|
||||
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
|
||||
/// Returns None if the user's edits conflict with the predicted edits.
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodestralEditPredictionDelegate {
|
||||
pub struct CodestralCompletionProvider {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
pending_request: Option<Task<Result<()>>>,
|
||||
current_completion: Option<CurrentCompletion>,
|
||||
}
|
||||
|
||||
impl CodestralEditPredictionDelegate {
|
||||
impl CodestralCompletionProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
@@ -165,7 +165,7 @@ impl CodestralEditPredictionDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionDelegate for CodestralEditPredictionDelegate {
|
||||
impl EditPredictionProvider for CodestralCompletionProvider {
|
||||
fn name() -> &'static str {
|
||||
"codestral"
|
||||
}
|
||||
@@ -174,7 +174,7 @@ impl EditPredictionDelegate for CodestralEditPredictionDelegate {
|
||||
"Codestral"
|
||||
}
|
||||
|
||||
fn show_predictions_in_menu() -> bool {
|
||||
fn show_completions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -239,6 +239,7 @@ impl EditPredictionDelegate for CodestralEditPredictionDelegate {
|
||||
cursor_point,
|
||||
&snapshot,
|
||||
&EXCERPT_OPTIONS,
|
||||
None,
|
||||
)
|
||||
.context("Line containing cursor doesn't fit in excerpt max bytes")?;
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ tokio = { workspace = true, features = ["full"] }
|
||||
toml.workspace = true
|
||||
tower = "0.4"
|
||||
tower-http = { workspace = true, features = ["trace"] }
|
||||
tracing.workspace = true
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "registry", "tracing-log"] } # workaround for https://github.com/tokio-rs/tracing/issues/2927
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -121,8 +121,6 @@ CREATE TABLE "project_repositories" (
|
||||
"merge_message" VARCHAR,
|
||||
"branch_summary" VARCHAR,
|
||||
"head_commit_details" VARCHAR,
|
||||
"remote_upstream_url" VARCHAR,
|
||||
"remote_origin_url" VARCHAR,
|
||||
PRIMARY KEY (project_id, id)
|
||||
);
|
||||
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
ALTER TABLE "project_repositories" ADD COLUMN "remote_upstream_url" VARCHAR;
|
||||
ALTER TABLE "project_repositories" ADD COLUMN "remote_origin_url" VARCHAR;
|
||||
@@ -362,8 +362,6 @@ impl Database {
|
||||
entry_ids: ActiveValue::set("[]".into()),
|
||||
head_commit_details: ActiveValue::set(None),
|
||||
merge_message: ActiveValue::set(None),
|
||||
remote_upstream_url: ActiveValue::set(None),
|
||||
remote_origin_url: ActiveValue::set(None),
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -513,8 +511,6 @@ impl Database {
|
||||
serde_json::to_string(&update.current_merge_conflicts).unwrap(),
|
||||
)),
|
||||
merge_message: ActiveValue::set(update.merge_message.clone()),
|
||||
remote_upstream_url: ActiveValue::set(update.remote_upstream_url.clone()),
|
||||
remote_origin_url: ActiveValue::set(update.remote_origin_url.clone()),
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
@@ -1009,8 +1005,6 @@ impl Database {
|
||||
is_last_update: true,
|
||||
merge_message: db_repository_entry.merge_message,
|
||||
stash_entries: Vec::new(),
|
||||
remote_upstream_url: db_repository_entry.remote_upstream_url.clone(),
|
||||
remote_origin_url: db_repository_entry.remote_origin_url.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -796,8 +796,6 @@ impl Database {
|
||||
is_last_update: true,
|
||||
merge_message: db_repository.merge_message,
|
||||
stash_entries: Vec::new(),
|
||||
remote_upstream_url: db_repository.remote_upstream_url.clone(),
|
||||
remote_origin_url: db_repository.remote_origin_url.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ pub struct Model {
|
||||
pub branch_summary: Option<String>,
|
||||
// A JSON object representing the current Head commit values
|
||||
pub head_commit_details: Option<String>,
|
||||
pub remote_upstream_url: Option<String>,
|
||||
pub remote_origin_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use call::Room;
|
||||
use client::ChannelId;
|
||||
use gpui::{Entity, TestAppContext};
|
||||
@@ -16,6 +18,7 @@ mod randomized_test_helpers;
|
||||
mod remote_editing_collaboration_tests;
|
||||
mod test_server;
|
||||
|
||||
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
pub use randomized_test_helpers::{
|
||||
RandomizedTest, TestError, UserTestPlan, run_randomized_test, save_randomized_test_plan,
|
||||
};
|
||||
@@ -48,3 +51,17 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
|
||||
fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
|
||||
cx.read(|cx| room.read(cx).channel_id())
|
||||
}
|
||||
|
||||
fn rust_lang() -> Arc<Language> {
|
||||
Arc::new(Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
|
||||
use crate::{
|
||||
rpc::RECONNECT_TIMEOUT,
|
||||
tests::{TestServer, rust_lang},
|
||||
};
|
||||
use call::ActiveCall;
|
||||
use editor::{
|
||||
DocumentColorsRenderMode, Editor, FETCH_COLORS_DEBOUNCE_TIMEOUT, MultiBufferOffset, RowInfo,
|
||||
@@ -20,7 +23,7 @@ use gpui::{
|
||||
App, Rgba, SharedString, TestAppContext, UpdateGlobal, VisualContext, VisualTestContext,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language::{FakeLspAdapter, rust_lang};
|
||||
use language::FakeLspAdapter;
|
||||
use lsp::LSP_REQUEST_TIMEOUT;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{
|
||||
@@ -3515,6 +3518,7 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
|
||||
.into_iter()
|
||||
.map(|(sha, message)| (sha.parse().unwrap(), message.into()))
|
||||
.collect(),
|
||||
remote_url: Some("git@github.com:zed-industries/zed.git".to_string()),
|
||||
};
|
||||
client_a.fs().set_blame_for_repo(
|
||||
Path::new(path!("/my-repo/.git")),
|
||||
@@ -3599,6 +3603,10 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
|
||||
for (idx, (buffer, entry)) in entries.iter().flatten().enumerate() {
|
||||
let details = blame.details_for_entry(*buffer, entry).unwrap();
|
||||
assert_eq!(details.message, format!("message for idx-{}", idx));
|
||||
assert_eq!(
|
||||
details.permalink.unwrap().to_string(),
|
||||
format!("https://github.com/zed-industries/zed/commit/{}", entry.sha)
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{
|
||||
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
|
||||
tests::{
|
||||
RoomParticipants, TestClient, TestServer, channel_id, following_tests::join_channel,
|
||||
room_participants,
|
||||
room_participants, rust_lang,
|
||||
},
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
@@ -26,7 +26,7 @@ use language::{
|
||||
Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig,
|
||||
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
|
||||
language_settings::{Formatter, FormatterList},
|
||||
rust_lang, tree_sitter_rust, tree_sitter_typescript,
|
||||
tree_sitter_rust, tree_sitter_typescript,
|
||||
};
|
||||
use lsp::{LanguageServerId, OneOf};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
@@ -33,7 +33,7 @@ fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pub mod copilot_chat;
|
||||
mod copilot_edit_prediction_delegate;
|
||||
mod copilot_completion_provider;
|
||||
pub mod copilot_responses;
|
||||
pub mod request;
|
||||
mod sign_in;
|
||||
@@ -46,7 +46,7 @@ use util::rel_path::RelPath;
|
||||
use util::{ResultExt, fs::remove_matching};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
|
||||
pub use crate::copilot_completion_provider::CopilotCompletionProvider;
|
||||
pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
|
||||
|
||||
actions!(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{Completion, Copilot};
|
||||
use anyhow::Result;
|
||||
use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
|
||||
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
|
||||
use settings::Settings;
|
||||
@@ -8,7 +8,7 @@ use std::{path::Path, time::Duration};
|
||||
|
||||
pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
|
||||
|
||||
pub struct CopilotEditPredictionDelegate {
|
||||
pub struct CopilotCompletionProvider {
|
||||
cycled: bool,
|
||||
buffer_id: Option<EntityId>,
|
||||
completions: Vec<Completion>,
|
||||
@@ -19,7 +19,7 @@ pub struct CopilotEditPredictionDelegate {
|
||||
copilot: Entity<Copilot>,
|
||||
}
|
||||
|
||||
impl CopilotEditPredictionDelegate {
|
||||
impl CopilotCompletionProvider {
|
||||
pub fn new(copilot: Entity<Copilot>) -> Self {
|
||||
Self {
|
||||
cycled: false,
|
||||
@@ -47,7 +47,7 @@ impl CopilotEditPredictionDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionDelegate for CopilotEditPredictionDelegate {
|
||||
impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
fn name() -> &'static str {
|
||||
"copilot"
|
||||
}
|
||||
@@ -56,7 +56,7 @@ impl EditPredictionDelegate for CopilotEditPredictionDelegate {
|
||||
"Copilot"
|
||||
}
|
||||
|
||||
fn show_predictions_in_menu() -> bool {
|
||||
fn show_completions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -314,7 +314,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -546,7 +546,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -670,7 +670,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -753,7 +753,7 @@ mod tests {
|
||||
window.focus(&editor.focus_handle(cx));
|
||||
})
|
||||
.unwrap();
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
editor
|
||||
.update(cx, |editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
@@ -848,7 +848,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -1000,7 +1000,7 @@ mod tests {
|
||||
window.focus(&editor.focus_handle(cx))
|
||||
})
|
||||
.unwrap();
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
editor
|
||||
.update(cx, |editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
@@ -37,7 +37,6 @@ dap_adapters = { workspace = true, optional = true }
|
||||
db.workspace = true
|
||||
debugger_tools.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
file_icons.workspace = true
|
||||
futures.workspace = true
|
||||
fuzzy.workspace = true
|
||||
@@ -83,7 +82,6 @@ dap_adapters = { workspace = true, features = ["test-support"] }
|
||||
debugger_tools = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-go.workspace = true
|
||||
unindent.workspace = true
|
||||
|
||||
@@ -15,11 +15,10 @@ use dap::adapters::DebugAdapterName;
|
||||
use dap::{DapRegistry, StartDebuggingRequestArguments};
|
||||
use dap::{client::SessionId, debugger_settings::DebuggerSettings};
|
||||
use editor::{Editor, MultiBufferOffset, ToPoint};
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
|
||||
use gpui::{
|
||||
Action, App, AsyncWindowContext, ClipboardItem, Context, Corner, DismissEvent, Entity,
|
||||
EntityId, EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point,
|
||||
Subscription, Task, WeakEntity, anchored, deferred,
|
||||
Action, App, AsyncWindowContext, ClipboardItem, Context, DismissEvent, Entity, EntityId,
|
||||
EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task,
|
||||
WeakEntity, anchored, deferred,
|
||||
};
|
||||
|
||||
use itertools::Itertools as _;
|
||||
@@ -32,9 +31,7 @@ use settings::Settings;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use task::{DebugScenario, TaskContext};
|
||||
use tree_sitter::{Query, StreamingIterator as _};
|
||||
use ui::{
|
||||
ContextMenu, Divider, PopoverMenu, PopoverMenuHandle, SplitButton, Tab, Tooltip, prelude::*,
|
||||
};
|
||||
use ui::{ContextMenu, Divider, PopoverMenuHandle, Tab, Tooltip, prelude::*};
|
||||
use util::rel_path::RelPath;
|
||||
use util::{ResultExt, debug_panic, maybe};
|
||||
use workspace::SplitDirection;
|
||||
@@ -45,12 +42,6 @@ use workspace::{
|
||||
};
|
||||
use zed_actions::ToggleFocus;
|
||||
|
||||
pub struct DebuggerHistoryFeatureFlag;
|
||||
|
||||
impl FeatureFlag for DebuggerHistoryFeatureFlag {
|
||||
const NAME: &'static str = "debugger-history";
|
||||
}
|
||||
|
||||
const DEBUG_PANEL_KEY: &str = "DebugPanel";
|
||||
|
||||
pub struct DebugPanel {
|
||||
@@ -293,7 +284,7 @@ impl DebugPanel {
|
||||
}
|
||||
});
|
||||
|
||||
session.update(cx, |session, _| match &mut session.state {
|
||||
session.update(cx, |session, _| match &mut session.mode {
|
||||
SessionState::Booting(state_task) => {
|
||||
*state_task = Some(boot_task);
|
||||
}
|
||||
@@ -671,12 +662,6 @@ impl DebugPanel {
|
||||
)
|
||||
};
|
||||
|
||||
let thread_status = active_session
|
||||
.as_ref()
|
||||
.map(|session| session.read(cx).running_state())
|
||||
.and_then(|state| state.read(cx).thread_status(cx))
|
||||
.unwrap_or(project::debugger::session::ThreadStatus::Exited);
|
||||
|
||||
Some(
|
||||
div.w_full()
|
||||
.py_1()
|
||||
@@ -694,6 +679,10 @@ impl DebugPanel {
|
||||
.as_ref()
|
||||
.map(|session| session.read(cx).running_state()),
|
||||
|this, running_state| {
|
||||
let thread_status =
|
||||
running_state.read(cx).thread_status(cx).unwrap_or(
|
||||
project::debugger::session::ThreadStatus::Exited,
|
||||
);
|
||||
let capabilities = running_state.read(cx).capabilities(cx);
|
||||
let supports_detach =
|
||||
running_state.read(cx).session().read(cx).is_attached();
|
||||
@@ -882,53 +871,36 @@ impl DebugPanel {
|
||||
}
|
||||
}),
|
||||
)
|
||||
.when(supports_detach, |div| {
|
||||
div.child(
|
||||
IconButton::new(
|
||||
"debug-disconnect",
|
||||
IconName::DebugDetach,
|
||||
)
|
||||
.disabled(
|
||||
thread_status != ThreadStatus::Stopped
|
||||
&& thread_status != ThreadStatus::Running,
|
||||
)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click(window.listener_for(
|
||||
running_state,
|
||||
|this, _, _, cx| {
|
||||
this.detach_client(cx);
|
||||
},
|
||||
))
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |_window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Detach",
|
||||
&Detach,
|
||||
&focus_handle,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when(
|
||||
cx.has_flag::<DebuggerHistoryFeatureFlag>(),
|
||||
|this| {
|
||||
this.child(Divider::vertical()).child(
|
||||
SplitButton::new(
|
||||
self.render_history_button(
|
||||
&running_state,
|
||||
thread_status,
|
||||
window,
|
||||
),
|
||||
self.render_history_toggle_button(
|
||||
thread_status,
|
||||
&running_state,
|
||||
)
|
||||
.into_any_element(),
|
||||
supports_detach,
|
||||
|div| {
|
||||
div.child(
|
||||
IconButton::new(
|
||||
"debug-disconnect",
|
||||
IconName::DebugDetach,
|
||||
)
|
||||
.style(ui::SplitButtonStyle::Outlined),
|
||||
.disabled(
|
||||
thread_status != ThreadStatus::Stopped
|
||||
&& thread_status != ThreadStatus::Running,
|
||||
)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click(window.listener_for(
|
||||
running_state,
|
||||
|this, _, _, cx| {
|
||||
this.detach_client(cx);
|
||||
},
|
||||
))
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |_window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Detach",
|
||||
&Detach,
|
||||
&focus_handle,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}),
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -1345,97 +1317,6 @@ impl DebugPanel {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn render_history_button(
|
||||
&self,
|
||||
running_state: &Entity<RunningState>,
|
||||
thread_status: ThreadStatus,
|
||||
window: &mut Window,
|
||||
) -> IconButton {
|
||||
IconButton::new("debug-back-in-history", IconName::HistoryRerun)
|
||||
.icon_size(IconSize::Small)
|
||||
.on_click(window.listener_for(running_state, |this, _, _window, cx| {
|
||||
this.session().update(cx, |session, cx| {
|
||||
let ix = session
|
||||
.active_snapshot_index()
|
||||
.unwrap_or_else(|| session.historic_snapshots().len());
|
||||
|
||||
session.select_historic_snapshot(Some(ix.saturating_sub(1)), cx);
|
||||
})
|
||||
}))
|
||||
.disabled(
|
||||
thread_status == ThreadStatus::Running || thread_status == ThreadStatus::Stepping,
|
||||
)
|
||||
}
|
||||
|
||||
fn render_history_toggle_button(
|
||||
&self,
|
||||
thread_status: ThreadStatus,
|
||||
running_state: &Entity<RunningState>,
|
||||
) -> impl IntoElement {
|
||||
PopoverMenu::new("debug-back-in-history-menu")
|
||||
.trigger(
|
||||
ui::ButtonLike::new_rounded_right("debug-back-in-history-menu-trigger")
|
||||
.layer(ui::ElevationIndex::ModalSurface)
|
||||
.size(ui::ButtonSize::None)
|
||||
.child(
|
||||
div()
|
||||
.px_1()
|
||||
.child(Icon::new(IconName::ChevronDown).size(IconSize::XSmall)),
|
||||
)
|
||||
.disabled(
|
||||
thread_status == ThreadStatus::Running
|
||||
|| thread_status == ThreadStatus::Stepping,
|
||||
),
|
||||
)
|
||||
.menu({
|
||||
let running_state = running_state.clone();
|
||||
move |window, cx| {
|
||||
let handler =
|
||||
|ix: Option<usize>, running_state: Entity<RunningState>, cx: &mut App| {
|
||||
running_state.update(cx, |state, cx| {
|
||||
state.session().update(cx, |session, cx| {
|
||||
session.select_historic_snapshot(ix, cx);
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
let running_state = running_state.clone();
|
||||
Some(ContextMenu::build(
|
||||
window,
|
||||
cx,
|
||||
move |mut context_menu, _window, cx| {
|
||||
let history = running_state
|
||||
.read(cx)
|
||||
.session()
|
||||
.read(cx)
|
||||
.historic_snapshots();
|
||||
|
||||
context_menu = context_menu.entry("Current State", None, {
|
||||
let running_state = running_state.clone();
|
||||
move |_window, cx| {
|
||||
handler(None, running_state.clone(), cx);
|
||||
}
|
||||
});
|
||||
context_menu = context_menu.separator();
|
||||
|
||||
for (ix, _) in history.iter().enumerate().rev() {
|
||||
context_menu =
|
||||
context_menu.entry(format!("history-{}", ix + 1), None, {
|
||||
let running_state = running_state.clone();
|
||||
move |_window, cx| {
|
||||
handler(Some(ix), running_state.clone(), cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
context_menu
|
||||
},
|
||||
))
|
||||
}
|
||||
})
|
||||
.anchor(Corner::TopRight)
|
||||
}
|
||||
}
|
||||
|
||||
async fn register_session_inner(
|
||||
|
||||
@@ -387,7 +387,7 @@ pub fn init(cx: &mut App) {
|
||||
window.on_action(
|
||||
TypeId::of::<editor::actions::EvaluateSelectedText>(),
|
||||
move |_, _, window, cx| {
|
||||
let status = maybe!({
|
||||
maybe!({
|
||||
let text = editor
|
||||
.update(cx, |editor, cx| {
|
||||
let range = editor
|
||||
@@ -411,13 +411,7 @@ pub fn init(cx: &mut App) {
|
||||
|
||||
state.session().update(cx, |session, cx| {
|
||||
session
|
||||
.evaluate(
|
||||
text,
|
||||
Some(dap::EvaluateArgumentsContext::Repl),
|
||||
stack_id,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.evaluate(text, None, stack_id, None, cx)
|
||||
.detach();
|
||||
});
|
||||
});
|
||||
@@ -425,9 +419,6 @@ pub fn init(cx: &mut App) {
|
||||
|
||||
Some(())
|
||||
});
|
||||
if status.is_some() {
|
||||
cx.stop_propagation();
|
||||
}
|
||||
},
|
||||
);
|
||||
})
|
||||
|
||||
@@ -881,6 +881,7 @@ impl ConfigureMode {
|
||||
.label("Stop on Entry")
|
||||
.label_position(SwitchLabelPosition::Start)
|
||||
.label_size(LabelSize::Default)
|
||||
.color(ui::SwitchColor::Accent)
|
||||
.on_click({
|
||||
let this = cx.weak_entity();
|
||||
move |state, _, cx| {
|
||||
@@ -1022,7 +1023,7 @@ impl DebugDelegate {
|
||||
Some(TaskSourceKind::Lsp { language_name, .. }) => {
|
||||
Some(format!("LSP: {language_name}"))
|
||||
}
|
||||
Some(TaskSourceKind::Language { name }) => Some(format!("Language: {name}")),
|
||||
Some(TaskSourceKind::Language { name }) => Some(format!("Lang: {name}")),
|
||||
_ => context.clone().and_then(|ctx| {
|
||||
ctx.task_context
|
||||
.task_variables
|
||||
|
||||
@@ -1743,7 +1743,7 @@ impl RunningState {
|
||||
|
||||
let is_building = self.session.update(cx, |session, cx| {
|
||||
session.shutdown(cx).detach();
|
||||
matches!(session.state, session::SessionState::Booting(_))
|
||||
matches!(session.mode, session::SessionState::Booting(_))
|
||||
});
|
||||
|
||||
if is_building {
|
||||
|
||||
@@ -17,9 +17,7 @@ impl LoadedSourceList {
|
||||
let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
|
||||
|
||||
let _subscription = cx.subscribe(&session, |this, _, event, cx| match event {
|
||||
SessionEvent::Stopped(_)
|
||||
| SessionEvent::HistoricSnapshotSelected
|
||||
| SessionEvent::LoadedSources => {
|
||||
SessionEvent::Stopped(_) | SessionEvent::LoadedSources => {
|
||||
this.invalidate = true;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -32,9 +32,7 @@ impl ModuleList {
|
||||
let focus_handle = cx.focus_handle();
|
||||
|
||||
let _subscription = cx.subscribe(&session, |this, _, event, cx| match event {
|
||||
SessionEvent::Stopped(_)
|
||||
| SessionEvent::HistoricSnapshotSelected
|
||||
| SessionEvent::Modules => {
|
||||
SessionEvent::Stopped(_) | SessionEvent::Modules => {
|
||||
if this._rebuild_task.is_some() {
|
||||
this.schedule_rebuild(cx);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use dap::StackFrameId;
|
||||
use dap::adapters::DebugAdapterName;
|
||||
use db::kvp::KEY_VALUE_STORE;
|
||||
use gpui::{
|
||||
Action, AnyElement, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, ListState,
|
||||
@@ -21,7 +20,7 @@ use project::debugger::breakpoint_store::ActiveStackFrame;
|
||||
use project::debugger::session::{Session, SessionEvent, StackFrame, ThreadStatus};
|
||||
use project::{ProjectItem, ProjectPath};
|
||||
use ui::{Tooltip, WithScrollbar, prelude::*};
|
||||
use workspace::{ItemHandle, Workspace, WorkspaceId};
|
||||
use workspace::{ItemHandle, Workspace};
|
||||
|
||||
use super::RunningState;
|
||||
|
||||
@@ -59,14 +58,6 @@ impl From<StackFrameFilter> for String {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn stack_frame_filter_key(
|
||||
adapter_name: &DebugAdapterName,
|
||||
workspace_id: WorkspaceId,
|
||||
) -> String {
|
||||
let database_id: i64 = workspace_id.into();
|
||||
format!("stack-frame-list-filter-{}-{}", adapter_name.0, database_id)
|
||||
}
|
||||
|
||||
pub struct StackFrameList {
|
||||
focus_handle: FocusHandle,
|
||||
_subscription: Subscription,
|
||||
@@ -106,9 +97,7 @@ impl StackFrameList {
|
||||
SessionEvent::Threads => {
|
||||
this.schedule_refresh(false, window, cx);
|
||||
}
|
||||
SessionEvent::Stopped(..)
|
||||
| SessionEvent::StackTrace
|
||||
| SessionEvent::HistoricSnapshotSelected => {
|
||||
SessionEvent::Stopped(..) | SessionEvent::StackTrace => {
|
||||
this.schedule_refresh(true, window, cx);
|
||||
}
|
||||
_ => {}
|
||||
@@ -116,18 +105,14 @@ impl StackFrameList {
|
||||
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
|
||||
|
||||
let list_filter = workspace
|
||||
.read_with(cx, |workspace, _| workspace.database_id())
|
||||
let list_filter = KEY_VALUE_STORE
|
||||
.read_kvp(&format!(
|
||||
"stack-frame-list-filter-{}",
|
||||
session.read(cx).adapter().0
|
||||
))
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|database_id| {
|
||||
let key = stack_frame_filter_key(&session.read(cx).adapter(), database_id);
|
||||
KEY_VALUE_STORE
|
||||
.read_kvp(&key)
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(StackFrameFilter::from_str_or_default)
|
||||
})
|
||||
.map(StackFrameFilter::from_str_or_default)
|
||||
.unwrap_or(StackFrameFilter::All);
|
||||
|
||||
let mut this = Self {
|
||||
@@ -240,6 +225,7 @@ impl StackFrameList {
|
||||
}
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.build_entries(select_first, window, cx);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
@@ -820,8 +806,15 @@ impl StackFrameList {
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
let key = stack_frame_filter_key(&self.session.read(cx).adapter(), database_id);
|
||||
let save_task = KEY_VALUE_STORE.write_kvp(key, self.list_filter.into());
|
||||
let database_id: i64 = database_id.into();
|
||||
let save_task = KEY_VALUE_STORE.write_kvp(
|
||||
format!(
|
||||
"stack-frame-list-filter-{}-{}",
|
||||
self.session.read(cx).adapter().0,
|
||||
database_id,
|
||||
),
|
||||
self.list_filter.into(),
|
||||
);
|
||||
cx.background_spawn(save_task).detach();
|
||||
}
|
||||
|
||||
|
||||
@@ -217,12 +217,6 @@ impl VariableList {
|
||||
let _subscriptions = vec![
|
||||
cx.subscribe(&stack_frame_list, Self::handle_stack_frame_list_events),
|
||||
cx.subscribe(&session, |this, _, event, cx| match event {
|
||||
SessionEvent::HistoricSnapshotSelected => {
|
||||
this.selection.take();
|
||||
this.edited_path.take();
|
||||
this.selected_stack_frame_id.take();
|
||||
this.build_entries(cx);
|
||||
}
|
||||
SessionEvent::Stopped(_) => {
|
||||
this.selection.take();
|
||||
this.edited_path.take();
|
||||
@@ -231,6 +225,7 @@ impl VariableList {
|
||||
SessionEvent::Variables | SessionEvent::Watchers => {
|
||||
this.build_entries(cx);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}),
|
||||
cx.on_focus_out(&focus_handle, window, |this, _, _, cx| {
|
||||
|
||||
@@ -4,7 +4,7 @@ use dap::{Scope, StackFrame, Variable, requests::Variables};
|
||||
use editor::{Editor, EditorMode, MultiBuffer};
|
||||
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
|
||||
use language::{
|
||||
Language, LanguageConfig, LanguageMatcher, rust_lang, tree_sitter_python,
|
||||
Language, LanguageConfig, LanguageMatcher, tree_sitter_python, tree_sitter_rust,
|
||||
tree_sitter_typescript,
|
||||
};
|
||||
use project::{FakeFs, Project};
|
||||
@@ -224,7 +224,7 @@ fn main() {
|
||||
.unwrap();
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_language(Some(rust_lang()), cx);
|
||||
buffer.set_language(Some(Arc::new(rust_lang())), cx);
|
||||
});
|
||||
|
||||
let (editor, cx) = cx.add_window_view(|window, cx| {
|
||||
@@ -1521,6 +1521,23 @@ fn main() {
|
||||
});
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
let debug_variables_query = include_str!("../../../languages/src/rust/debugger.scm");
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_python_inline_values(executor: BackgroundExecutor, cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
@@ -1842,23 +1859,21 @@ fn python_lang() -> Language {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn go_lang() -> Arc<Language> {
|
||||
fn go_lang() -> Language {
|
||||
let debug_variables_query = include_str!("../../../languages/src/go/debugger.scm");
|
||||
Arc::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Go".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["go".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Go".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["go".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_go::LANGUAGE.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap(),
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_go::LANGUAGE.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Test utility function for inline values testing
|
||||
@@ -1876,7 +1891,7 @@ async fn test_inline_values_util(
|
||||
before: &str,
|
||||
after: &str,
|
||||
active_debug_line: Option<usize>,
|
||||
language: Arc<Language>,
|
||||
language: Language,
|
||||
executor: BackgroundExecutor,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
@@ -2076,7 +2091,7 @@ async fn test_inline_values_util(
|
||||
.unwrap();
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_language(Some(language), cx);
|
||||
buffer.set_language(Some(Arc::new(language)), cx);
|
||||
});
|
||||
|
||||
let (editor, cx) = cx.add_window_view(|window, cx| {
|
||||
@@ -2261,61 +2276,55 @@ fn main() {
|
||||
.await;
|
||||
}
|
||||
|
||||
fn javascript_lang() -> Arc<Language> {
|
||||
fn javascript_lang() -> Language {
|
||||
let debug_variables_query = include_str!("../../../languages/src/javascript/debugger.scm");
|
||||
Arc::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "JavaScript".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["js".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "JavaScript".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["js".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap(),
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn typescript_lang() -> Arc<Language> {
|
||||
fn typescript_lang() -> Language {
|
||||
let debug_variables_query = include_str!("../../../languages/src/typescript/debugger.scm");
|
||||
Arc::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "TypeScript".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["ts".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "TypeScript".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["ts".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap(),
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn tsx_lang() -> Arc<Language> {
|
||||
fn tsx_lang() -> Language {
|
||||
let debug_variables_query = include_str!("../../../languages/src/tsx/debugger.scm");
|
||||
Arc::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "TSX".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["tsx".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "TSX".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["tsx".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap(),
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
|
||||
)
|
||||
.with_debug_variables_query(debug_variables_query)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
use crate::{
|
||||
debugger_panel::DebugPanel,
|
||||
session::running::stack_frame_list::{
|
||||
StackFrameEntry, StackFrameFilter, stack_frame_filter_key,
|
||||
},
|
||||
session::running::stack_frame_list::{StackFrameEntry, StackFrameFilter},
|
||||
tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session},
|
||||
};
|
||||
use dap::{
|
||||
StackFrame,
|
||||
requests::{Scopes, StackTrace, Threads},
|
||||
};
|
||||
use db::kvp::KEY_VALUE_STORE;
|
||||
use editor::{Editor, ToPoint as _};
|
||||
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
|
||||
use project::{FakeFs, Project};
|
||||
@@ -1088,180 +1085,3 @@ async fn test_stack_frame_filter(executor: BackgroundExecutor, cx: &mut TestAppC
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_stack_frame_filter_persistence(
|
||||
executor: BackgroundExecutor,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(executor.clone());
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/project"),
|
||||
json!({
|
||||
"src": {
|
||||
"test.js": "function main() { console.log('hello'); }",
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let workspace = init_test_workspace(&project, cx).await;
|
||||
let cx = &mut VisualTestContext::from_window(*workspace, cx);
|
||||
workspace
|
||||
.update(cx, |workspace, _, _| {
|
||||
workspace.set_random_database_id();
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let threads_response = dap::ThreadsResponse {
|
||||
threads: vec![dap::Thread {
|
||||
id: 1,
|
||||
name: "Thread 1".into(),
|
||||
}],
|
||||
};
|
||||
|
||||
let stack_trace_response = dap::StackTraceResponse {
|
||||
stack_frames: vec![StackFrame {
|
||||
id: 1,
|
||||
name: "main".into(),
|
||||
source: Some(dap::Source {
|
||||
name: Some("test.js".into()),
|
||||
path: Some(path!("/project/src/test.js").into()),
|
||||
source_reference: None,
|
||||
presentation_hint: None,
|
||||
origin: None,
|
||||
sources: None,
|
||||
adapter_data: None,
|
||||
checksums: None,
|
||||
}),
|
||||
line: 1,
|
||||
column: 1,
|
||||
end_line: None,
|
||||
end_column: None,
|
||||
can_restart: None,
|
||||
instruction_pointer_reference: None,
|
||||
module_id: None,
|
||||
presentation_hint: None,
|
||||
}],
|
||||
total_frames: None,
|
||||
};
|
||||
|
||||
let stopped_event = dap::StoppedEvent {
|
||||
reason: dap::StoppedEventReason::Pause,
|
||||
description: None,
|
||||
thread_id: Some(1),
|
||||
preserve_focus_hint: None,
|
||||
text: None,
|
||||
all_threads_stopped: None,
|
||||
hit_breakpoint_ids: None,
|
||||
};
|
||||
|
||||
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
|
||||
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
|
||||
let adapter_name = session.update(cx, |session, _| session.adapter());
|
||||
|
||||
client.on_request::<Threads, _>({
|
||||
let threads_response = threads_response.clone();
|
||||
move |_, _| Ok(threads_response.clone())
|
||||
});
|
||||
|
||||
client.on_request::<Scopes, _>(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] }));
|
||||
|
||||
client.on_request::<StackTrace, _>({
|
||||
let stack_trace_response = stack_trace_response.clone();
|
||||
move |_, _| Ok(stack_trace_response.clone())
|
||||
});
|
||||
|
||||
client
|
||||
.fake_event(dap::messages::Events::Stopped(stopped_event.clone()))
|
||||
.await;
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let stack_frame_list =
|
||||
active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| {
|
||||
debug_panel_item
|
||||
.running_state()
|
||||
.update(cx, |state, _| state.stack_frame_list().clone())
|
||||
});
|
||||
|
||||
stack_frame_list.update(cx, |stack_frame_list, _cx| {
|
||||
assert_eq!(
|
||||
stack_frame_list.list_filter(),
|
||||
StackFrameFilter::All,
|
||||
"Initial filter should be All"
|
||||
);
|
||||
});
|
||||
|
||||
stack_frame_list.update(cx, |stack_frame_list, cx| {
|
||||
stack_frame_list
|
||||
.toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx);
|
||||
assert_eq!(
|
||||
stack_frame_list.list_filter(),
|
||||
StackFrameFilter::OnlyUserFrames,
|
||||
"Filter should be OnlyUserFrames after toggle"
|
||||
);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let workspace_id = workspace
|
||||
.update(cx, |workspace, _window, _cx| workspace.database_id())
|
||||
.ok()
|
||||
.flatten()
|
||||
.expect("workspace id has to be some for this test to work properly");
|
||||
|
||||
let key = stack_frame_filter_key(&adapter_name, workspace_id);
|
||||
let stored_value = KEY_VALUE_STORE.read_kvp(&key).unwrap();
|
||||
assert_eq!(
|
||||
stored_value,
|
||||
Some(StackFrameFilter::OnlyUserFrames.into()),
|
||||
"Filter should be persisted in KVP store with key: {}",
|
||||
key
|
||||
);
|
||||
|
||||
client
|
||||
.fake_event(dap::messages::Events::Terminated(None))
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
|
||||
let session2 = start_debug_session(&workspace, cx, |_| {}).unwrap();
|
||||
let client2 = session2.update(cx, |session, _| session.adapter_client().unwrap());
|
||||
|
||||
client2.on_request::<Threads, _>({
|
||||
let threads_response = threads_response.clone();
|
||||
move |_, _| Ok(threads_response.clone())
|
||||
});
|
||||
|
||||
client2.on_request::<Scopes, _>(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] }));
|
||||
|
||||
client2.on_request::<StackTrace, _>({
|
||||
let stack_trace_response = stack_trace_response.clone();
|
||||
move |_, _| Ok(stack_trace_response.clone())
|
||||
});
|
||||
|
||||
client2
|
||||
.fake_event(dap::messages::Events::Stopped(stopped_event.clone()))
|
||||
.await;
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let stack_frame_list2 =
|
||||
active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| {
|
||||
debug_panel_item
|
||||
.running_state()
|
||||
.update(cx, |state, _| state.stack_frame_list().clone())
|
||||
});
|
||||
|
||||
stack_frame_list2.update(cx, |stack_frame_list, _cx| {
|
||||
assert_eq!(
|
||||
stack_frame_list.list_filter(),
|
||||
StackFrameFilter::OnlyUserFrames,
|
||||
"Filter should be restored from KVP store in new session"
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -11,69 +11,7 @@ workspace = true
|
||||
[lib]
|
||||
path = "src/edit_prediction.rs"
|
||||
|
||||
[features]
|
||||
eval-support = []
|
||||
|
||||
[dependencies]
|
||||
ai_onboarding.workspace = true
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
db.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
menu.workspace = true
|
||||
open_ai.workspace = true
|
||||
postage.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
release_channel.workspace = true
|
||||
semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strsim.workspace = true
|
||||
strum.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
thiserror.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client = { workspace = true, features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
lsp.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
use language::{BufferSnapshot, Point};
|
||||
use std::ops::Range;
|
||||
|
||||
pub fn editable_and_context_ranges_for_cursor_position(
|
||||
position: Point,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_region_token_limit: usize,
|
||||
context_token_limit: usize,
|
||||
) -> (Range<Point>, Range<Point>) {
|
||||
let mut scope_range = position..position;
|
||||
let mut remaining_edit_tokens = editable_region_token_limit;
|
||||
|
||||
while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
|
||||
let parent_tokens = guess_token_count(parent.byte_range().len());
|
||||
let parent_point_range = Point::new(
|
||||
parent.start_position().row as u32,
|
||||
parent.start_position().column as u32,
|
||||
)
|
||||
..Point::new(
|
||||
parent.end_position().row as u32,
|
||||
parent.end_position().column as u32,
|
||||
);
|
||||
if parent_point_range == scope_range {
|
||||
break;
|
||||
} else if parent_tokens <= editable_region_token_limit {
|
||||
scope_range = parent_point_range;
|
||||
remaining_edit_tokens = editable_region_token_limit - parent_tokens;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
|
||||
let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
|
||||
(editable_range, context_range)
|
||||
}
|
||||
|
||||
fn expand_range(
|
||||
snapshot: &BufferSnapshot,
|
||||
range: Range<Point>,
|
||||
mut remaining_tokens: usize,
|
||||
) -> Range<Point> {
|
||||
let mut expanded_range = range;
|
||||
expanded_range.start.column = 0;
|
||||
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
|
||||
loop {
|
||||
let mut expanded = false;
|
||||
|
||||
if remaining_tokens > 0 && expanded_range.start.row > 0 {
|
||||
expanded_range.start.row -= 1;
|
||||
let line_tokens =
|
||||
guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
|
||||
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
|
||||
expanded = true;
|
||||
}
|
||||
|
||||
if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
|
||||
expanded_range.end.row += 1;
|
||||
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
|
||||
let line_tokens = guess_token_count(expanded_range.end.column as usize);
|
||||
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
|
||||
expanded = true;
|
||||
}
|
||||
|
||||
if !expanded {
|
||||
break;
|
||||
}
|
||||
}
|
||||
expanded_range
|
||||
}
|
||||
|
||||
/// Typical number of string bytes per token for the purposes of limiting model input. This is
|
||||
/// intentionally low to err on the side of underestimating limits.
|
||||
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
|
||||
|
||||
pub fn guess_token_count(bytes: usize) -> usize {
|
||||
bytes / BYTES_PER_TOKEN_GUESS
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,340 +0,0 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
|
||||
/// Inception Labs edit-completions endpoint used for Mercury edit predictions.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
/// Token budget for the context surrounding the editable region in the prompt.
const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for the region the model is asked to rewrite.
const MAX_REWRITE_TOKENS: usize = 350;
|
||||
|
||||
/// Edit-prediction provider backed by Inception Labs' Mercury API.
pub struct Mercury {
    // Lazily-resolved API token, shared so concurrent readers await the same env-var /
    // keychain lookup. `None` when no token is configured.
    pub api_token: Shared<Task<Option<String>>>,
}
|
||||
|
||||
impl Mercury {
    /// Creates a provider whose API token is loaded in the background (env var first,
    /// then the system keychain).
    pub fn new(cx: &App) -> Self {
        Mercury {
            api_token: load_api_token(cx).shared(),
        }
    }

    /// Replaces the in-memory token immediately and persists the change (store or
    /// delete) to the system keychain in the returned task.
    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
        self.api_token = Task::ready(api_token.clone()).shared();
        store_api_token_in_keychain(api_token, cx)
    }

    /// Requests an edit prediction from the Mercury API for the cursor position in
    /// `active_buffer`.
    ///
    /// Resolves to `Ok(None)` when no API token is available yet; otherwise builds a
    /// prompt from recent events, related files, and an excerpt around the cursor,
    /// sends it to Mercury, and diffs the model's rewrite against the editable region
    /// to produce anchored edits.
    pub fn request_prediction(
        &self,
        _project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        _recent_paths: &VecDeque<ProjectPath>,
        related_files: Vec<RelatedFile>,
        _diagnostic_search_range: Range<Point>,
        cx: &mut App,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // `now_or_never` only yields a token if loading already finished; a still-pending
        // load is treated as "no prediction" rather than blocking.
        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let http_client = cx.http_client();
        let cursor_point = position.to_point(&snapshot);
        let buffer_snapshotted_at = Instant::now();

        // Prompt construction, the HTTP round trip, and response diffing all happen off
        // the main thread.
        let result = cx.background_spawn(async move {
            let (editable_range, context_range) =
                crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                    cursor_point,
                    &snapshot,
                    MAX_CONTEXT_TOKENS,
                    MAX_REWRITE_TOKENS,
                );

            let offset_range = editable_range.to_offset(&snapshot);
            let prompt = build_prompt(
                &events,
                &related_files,
                &snapshot,
                full_path.as_ref(),
                cursor_point,
                editable_range,
                context_range.clone(),
            );

            // Recorded alongside the prediction for debugging/telemetry; mirrors what was
            // actually sent to the model.
            let inputs = EditPredictionInputs {
                events: events,
                included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
                    path: full_path.clone(),
                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
                        start_line: cloud_llm_client::predict_edits_v3::Line(
                            context_range.start.row,
                        ),
                        text: snapshot
                            .text_for_range(context_range.clone())
                            .collect::<String>()
                            .into(),
                    }],
                }],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    column: cursor_point.column,
                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
                },
                cursor_path: full_path.clone(),
            };

            // Mercury speaks an OpenAI-compatible chat-completions protocol.
            let request_body = open_ai::Request {
                model: "mercury-coder".into(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: vec![],
                temperature: None,
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            let buf = serde_json::to_vec(&request_body)?;
            let body: AsyncBody = buf.into();

            let request = http_client::Request::builder()
                .uri(MERCURY_API_URL)
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", api_token))
                .header("Connection", "keep-alive")
                .method(Method::POST)
                .body(body)
                .context("Failed to create request")?;

            let mut response = http_client
                .send(request)
                .await
                .context("Failed to send request")?;

            let mut body: Vec<u8> = Vec::new();
            response
                .body_mut()
                .read_to_end(&mut body)
                .await
                .context("Failed to read response body")?;

            let response_received_at = Instant::now();
            if !response.status().is_success() {
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    String::from_utf8_lossy(&body),
                );
            };

            let mut response: open_ai::Response =
                serde_json::from_slice(&body).context("Failed to parse response")?;

            let id = mem::take(&mut response.id);
            let response_str = text_from_response(response).unwrap_or_default();

            // The model sometimes wraps its rewrite in a Markdown code fence; strip it.
            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);

            let mut edits = Vec::new();
            // Sentinel the model emits when it has nothing to suggest.
            const NO_PREDICTION_OUTPUT: &str = "None";

            if response_str != NO_PREDICTION_OUTPUT {
                // Diff the rewrite against the current editable region and convert each
                // hunk into an anchored edit so it survives concurrent buffer changes.
                let old_text = snapshot
                    .text_for_range(offset_range.clone())
                    .collect::<String>();
                edits.extend(
                    language::text_diff(&old_text, &response_str)
                        .into_iter()
                        .map(|(range, text)| {
                            (
                                snapshot.anchor_after(offset_range.start + range.start)
                                    ..snapshot.anchor_before(offset_range.start + range.end),
                                text,
                            )
                        }),
                );
            }

            anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
        });

        let buffer = active_buffer.clone();

        // Back on the foreground: package the edits into an EditPredictionResult.
        cx.spawn(async move |cx| {
            let (id, edits, old_snapshot, response_received_at, inputs) =
                result.await.context("Mercury edit prediction failed")?;
            anyhow::Ok(Some(
                EditPredictionResult::new(
                    EditPredictionId(id.into()),
                    &buffer,
                    &old_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    response_received_at,
                    inputs,
                    cx,
                )
                .await,
            ))
        })
    }
}
|
||||
|
||||
/// Assembles the Mercury prompt from its tagged sections, in order:
/// recently-viewed snippets, the current file (with the editable region and cursor
/// marked inline), and the edit diff history.
///
/// The tag strings below are part of the model's expected input format — do not change
/// them without coordinating with the model's training/serving side.
fn build_prompt(
    events: &[Arc<Event>],
    related_files: &[RelatedFile],
    cursor_buffer: &BufferSnapshot,
    cursor_buffer_path: &Path,
    cursor_point: Point,
    editable_range: Range<Point>,
    context_range: Range<Point>,
) -> String {
    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
    const CURSOR_TAG: &str = "<|cursor|>";
    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";

    let mut prompt = String::new();

    // Section 1: every excerpt of every related file, each wrapped in its own snippet
    // tags and prefixed with its path.
    push_delimited(
        &mut prompt,
        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
        |prompt| {
            for related_file in related_files {
                for related_excerpt in &related_file.excerpts {
                    push_delimited(
                        prompt,
                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
                        |prompt| {
                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
                            prompt.push_str(related_file.path.path.as_unix_str());
                            prompt.push('\n');
                            prompt.push_str(&related_excerpt.text.to_string());
                        },
                    );
                }
            }
        },
    );

    // Section 2: the current file's context window, with the editable region wrapped in
    // code_to_edit tags and the cursor marked inline.
    push_delimited(
        &mut prompt,
        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
        |prompt| {
            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
            prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
            prompt.push('\n');

            // Context before/after the editable region (still inside the context window).
            let prefix_range = context_range.start..editable_range.start;
            let suffix_range = editable_range.end..context_range.end;

            prompt.extend(cursor_buffer.text_for_range(prefix_range));
            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
                let range_before_cursor = editable_range.start..cursor_point;
                let range_after_cursor = cursor_point..editable_range.end;
                prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
                prompt.push_str(CURSOR_TAG);
                prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
            });
            prompt.extend(cursor_buffer.text_for_range(suffix_range));
        },
    );

    // Section 3: recent edit events, one per line via their Display impls.
    push_delimited(
        &mut prompt,
        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
        |prompt| {
            for event in events {
                writeln!(prompt, "{event}").unwrap();
            }
        },
    );

    prompt
}
|
||||
|
||||
/// Appends `delimiters.start`, then whatever `cb` writes, then `delimiters.end` to
/// `prompt`. The `Range` is just a convenient pair of open/close tag strings.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range { start, end } = delimiters;
    prompt.push_str(start);
    cb(prompt);
    prompt.push_str(end);
}
|
||||
|
||||
/// Keychain entry URL under which the Mercury API token is stored; matches the API
/// endpoint so credentials are keyed to the service they authenticate against.
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
/// Username component of the keychain entry for the Mercury API token.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
|
||||
|
||||
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
|
||||
if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
|
||||
.ok()
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
return Task::ready(Some(api_token));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, credentials) = credentials_provider
|
||||
.read_credentials(MERCURY_CREDENTIALS_URL, &cx)
|
||||
.await
|
||||
.ok()??;
|
||||
String::from_utf8(credentials).ok()
|
||||
})
|
||||
}
|
||||
|
||||
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
if let Some(api_token) = api_token {
|
||||
credentials_provider
|
||||
.write_credentials(
|
||||
MERCURY_CREDENTIALS_URL,
|
||||
MERCURY_CREDENTIALS_USERNAME,
|
||||
api_token.as_bytes(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.context("Failed to save Mercury API token to system keychain")
|
||||
} else {
|
||||
credentials_provider
|
||||
.delete_credentials(MERCURY_CREDENTIALS_URL, cx)
|
||||
.await
|
||||
.context("Failed to delete Mercury API token from system keychain")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
|
||||
let choice = res.choices.pop()?;
|
||||
let output_text = match choice.message {
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(content)),
|
||||
..
|
||||
} => content,
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Multipart(mut content)),
|
||||
..
|
||||
} => {
|
||||
if content.is_empty() {
|
||||
log::error!("No output from Baseten completion response");
|
||||
return None;
|
||||
}
|
||||
|
||||
match content.remove(0) {
|
||||
open_ai::MessagePart::Text { text } => text,
|
||||
open_ai::MessagePart::Image { .. } => {
|
||||
log::error!("Expected text, got an image");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
log::error!("Invalid response message: {:?}", choice.message);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(output_text)
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
#[cfg(feature = "eval-support")]
|
||||
use crate::EvalCacheEntryKind;
|
||||
use crate::open_ai_response::text_from_response;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
use crate::{
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
|
||||
EditPredictionRequestedDebugEvent, EditPredictionStore,
|
||||
};
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, Line};
|
||||
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{Entity, Task, prelude::*};
|
||||
use language::{Anchor, BufferSnapshot};
|
||||
use language::{Buffer, Point, ToOffset as _, ToPoint};
|
||||
use project::{Project, ProjectItem as _};
|
||||
use release_channel::AppVersion;
|
||||
use std::{
|
||||
env,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
/// Requests an edit prediction via the Zeta2 cloud endpoint.
///
/// Selects an excerpt around the cursor, merges it into `included_files`, builds a
/// predict-edits-v3 request and prompt, sends it through the LLM client, and parses the
/// model's output (udiff or XML edits, depending on prompt format) into anchored edits.
/// Resolves to `Ok(None)` when no excerpt could be selected; an empty model response is
/// reported as `EditPredictionRejectReason::Empty`.
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    project: &Entity<Project>,
    active_buffer: &Entity<Buffer>,
    active_snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<Event>>,
    mut included_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let options = store.options.clone();
    let buffer_snapshotted_at = Instant::now();

    // Both a full path (for the request) and a project path (for RelatedFile) are
    // required; bail out for unsaved/untitled buffers.
    let Some((excerpt_path, active_project_path)) = active_snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .zip(active_buffer.read(cx).project_path(cx))
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let debug_tx = store.debug_tx.clone();

    let file = active_buffer.read(cx).file();

    let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));

    // TODO data collection
    let can_collect_data = file
        .as_ref()
        .map_or(false, |file| store.can_collect_file(project, file, cx));

    #[cfg(feature = "eval-support")]
    let eval_cache = store.eval_cache.clone();

    // Excerpt selection, prompt building, the network round trip, and output parsing
    // all run off the main thread.
    let request_task = cx.background_spawn({
        let active_buffer = active_buffer.clone();
        async move {
            let cursor_offset = position.to_offset(&active_snapshot);
            let cursor_point = cursor_offset.to_point(&active_snapshot);

            let before_retrieval = Instant::now();

            let excerpt_options = options.context;

            let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                cursor_point,
                &active_snapshot,
                &excerpt_options,
            ) else {
                // No usable excerpt around the cursor: no prediction.
                return Ok((None, None));
            };

            let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                ..active_snapshot.anchor_before(excerpt.range.end);
            let related_excerpt = RelatedExcerpt {
                anchor_range: excerpt_anchor_range.clone(),
                point_range: Point::new(excerpt.line_range.start.0, 0)
                    ..Point::new(excerpt.line_range.end.0, 0),
                text: active_snapshot.as_rope().slice(excerpt.range),
            };

            // Fold the cursor excerpt into the included files, keeping the active file
            // last (the swap moves it to the end of the list).
            if let Some(buffer_ix) = included_files
                .iter()
                .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
            {
                let file = &mut included_files[buffer_ix];
                file.excerpts.push(related_excerpt);
                file.merge_excerpts();
                let last_ix = included_files.len() - 1;
                included_files.swap(buffer_ix, last_ix);
            } else {
                let active_file = RelatedFile {
                    path: active_project_path,
                    buffer: active_buffer.downgrade(),
                    excerpts: vec![related_excerpt],
                    max_row: active_snapshot.max_point().row,
                };
                included_files.push(active_file);
            }

            // Convert to the wire representation used by predict_edits_v3.
            let included_files = included_files
                .iter()
                .map(|related_file| predict_edits_v3::RelatedFile {
                    path: Arc::from(related_file.path.path.as_std_path()),
                    max_row: Line(related_file.max_row),
                    excerpts: related_file
                        .excerpts
                        .iter()
                        .map(|excerpt| predict_edits_v3::Excerpt {
                            start_line: Line(excerpt.point_range.start.row),
                            text: excerpt.text.to_string().into(),
                        })
                        .collect(),
                })
                .collect::<Vec<_>>();

            // The legacy excerpt fields are unused by this path and left empty.
            let cloud_request = predict_edits_v3::PredictEditsRequest {
                excerpt_path,
                excerpt: String::new(),
                excerpt_line_range: Line(0)..Line(0),
                excerpt_range: 0..0,
                cursor_point: predict_edits_v3::Point {
                    line: predict_edits_v3::Line(cursor_point.row),
                    column: cursor_point.column,
                },
                related_files: included_files,
                events,
                can_collect_data,
                debug_info: debug_tx.is_some(),
                prompt_max_bytes: Some(options.max_prompt_bytes),
                prompt_format: options.prompt_format,
                excerpt_parent: None,
                git_info: None,
                trigger,
            };

            let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

            let inputs = EditPredictionInputs {
                included_files: cloud_request.related_files,
                events: cloud_request.events,
                cursor_point: cloud_request.cursor_point,
                cursor_path: cloud_request.excerpt_path,
            };

            let retrieval_time = Instant::now() - before_retrieval;

            // When a debug listener is attached, announce the request and keep a channel
            // open for reporting the raw response / timing later.
            let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                let (response_tx, response_rx) = oneshot::channel();

                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionRequested(
                        EditPredictionRequestedDebugEvent {
                            inputs: inputs.clone(),
                            retrieval_time,
                            buffer: active_buffer.downgrade(),
                            local_prompt: match prompt_result.as_ref() {
                                Ok(prompt) => Ok(prompt.clone()),
                                Err(err) => Err(err.to_string()),
                            },
                            position,
                            response_rx,
                        },
                    ))
                    .ok();
                Some(response_tx)
            } else {
                None
            };

            // Dev-only escape hatch: skip the network call entirely.
            if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((Err("Request skipped".to_string()), Duration::ZERO))
                        .ok();
                }
                anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

            let prompt = prompt_result?;
            let generation_params =
                cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: generation_params.stop.unwrap_or_default(),
                temperature: generation_params.temperature.or(Some(0.7)),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let before_request = Instant::now();
            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache,
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();
            let request_time = received_response_at - before_request;

            log::trace!("Got edit prediction response");

            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send((
                        response
                            .as_ref()
                            .map_err(|err| err.to_string())
                            .map(|response| response.0.clone()),
                        request_time,
                    ))
                    .ok();
            }

            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = text_from_response(res) else {
                // No text in the response: record the request id but no prediction.
                return Ok((Some((request_id, None)), usage));
            };

            // The model occasionally echoes the cursor marker; it must not end up in
            // the applied edits.
            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

            // Diff/edit parsing may reference the active file by path; resolve it to the
            // snapshot and excerpt range captured above.
            let get_buffer_from_context = |path: &Path| {
                if Some(path) == active_file_full_path.as_deref() {
                    Some((
                        &active_snapshot,
                        std::slice::from_ref(&excerpt_anchor_range),
                    ))
                } else {
                    None
                }
            };

            let (_, edits) = match options.prompt_format {
                PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
                    // Sentinel diff header the model emits when there is nothing to edit.
                    if output_text.contains("--- a/\n+++ b/\nNo edits") {
                        let edits = vec![];
                        (&active_snapshot, edits)
                    } else {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                }
                PromptFormat::OldTextNewText => {
                    crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
                }
                _ => {
                    bail!("unsupported prompt format {}", options.prompt_format)
                }
            };

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        inputs,
                        active_buffer,
                        active_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

    // Back on the foreground: account for usage, then wrap the edits (or the empty
    // rejection) into an EditPredictionResult.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "edit_prediction_ui"
|
||||
name = "edit_prediction_button"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
@@ -9,43 +9,35 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/edit_prediction_ui.rs"
|
||||
path = "src/edit_prediction_button.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
codestral.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
copilot.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
language.workspace = true
|
||||
markdown.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
settings.workspace = true
|
||||
supermaven.workspace = true
|
||||
telemetry.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
menu.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
copilot = { workspace = true, features = ["test-support"] }
|
||||
@@ -1,14 +1,16 @@
|
||||
mod sweep_api_token_modal;
|
||||
|
||||
pub use sweep_api_token_modal::SweepApiKeyModal;
|
||||
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use cloud_llm_client::UsageLimit;
|
||||
use codestral::CodestralEditPredictionDelegate;
|
||||
use codestral::CodestralCompletionProvider;
|
||||
use copilot::{Copilot, Status};
|
||||
use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
use edit_prediction_types::EditPredictionDelegateHandle;
|
||||
use editor::{
|
||||
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
|
||||
};
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, AsyncWindowContext, Corner, Entity, FocusHandle,
|
||||
@@ -23,7 +25,6 @@ use language::{
|
||||
use project::DisableAiSettings;
|
||||
use regex::Regex;
|
||||
use settings::{
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
|
||||
update_settings_file,
|
||||
@@ -43,11 +44,7 @@ use workspace::{
|
||||
notifications::NotificationId,
|
||||
};
|
||||
use zed_actions::OpenBrowser;
|
||||
|
||||
use crate::{
|
||||
ExternalProviderApiKeyModal, RatePredictions,
|
||||
rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
|
||||
};
|
||||
use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
|
||||
actions!(
|
||||
edit_prediction,
|
||||
@@ -70,7 +67,7 @@ pub struct EditPredictionButton {
|
||||
editor_focus_handle: Option<FocusHandle>,
|
||||
language: Option<Arc<Language>>,
|
||||
file: Option<Arc<dyn File>>,
|
||||
edit_prediction_provider: Option<Arc<dyn EditPredictionDelegateHandle>>,
|
||||
edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
user_store: Entity<UserStore>,
|
||||
popover_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
@@ -247,7 +244,7 @@ impl Render for EditPredictionButton {
|
||||
|
||||
EditPredictionProvider::Codestral => {
|
||||
let enabled = self.editor_enabled.unwrap_or(true);
|
||||
let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx);
|
||||
let has_api_key = CodestralCompletionProvider::has_api_key(cx);
|
||||
let fs = self.fs.clone();
|
||||
let this = cx.weak_entity();
|
||||
|
||||
@@ -312,34 +309,24 @@ impl Render for EditPredictionButton {
|
||||
provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
|
||||
let enabled = self.editor_enabled.unwrap_or(true);
|
||||
|
||||
let ep_icon;
|
||||
let mut missing_token = false;
|
||||
let is_sweep = matches!(
|
||||
provider,
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
|
||||
)
|
||||
);
|
||||
|
||||
match provider {
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
ep_icon = IconName::SweepAi;
|
||||
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token());
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
ep_icon = IconName::Inception;
|
||||
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token());
|
||||
}
|
||||
_ => {
|
||||
ep_icon = if enabled {
|
||||
IconName::ZedPredict
|
||||
} else {
|
||||
IconName::ZedPredictDisabled
|
||||
};
|
||||
}
|
||||
let sweep_missing_token = is_sweep
|
||||
&& !zeta::Zeta::try_global(cx)
|
||||
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
|
||||
|
||||
let zeta_icon = match (is_sweep, enabled) {
|
||||
(true, _) => IconName::SweepAi,
|
||||
(false, true) => IconName::ZedPredict,
|
||||
(false, false) => IconName::ZedPredictDisabled,
|
||||
};
|
||||
|
||||
if edit_prediction::should_show_upsell_modal() {
|
||||
if zeta::should_show_upsell_modal() {
|
||||
let tooltip_meta = if self.user_store.read(cx).current_user().is_some() {
|
||||
"Choose a Plan"
|
||||
} else {
|
||||
@@ -347,7 +334,7 @@ impl Render for EditPredictionButton {
|
||||
};
|
||||
|
||||
return div().child(
|
||||
IconButton::new("zed-predict-pending-button", ep_icon)
|
||||
IconButton::new("zed-predict-pending-button", zeta_icon)
|
||||
.shape(IconButtonShape::Square)
|
||||
.indicator(Indicator::dot().color(Color::Muted))
|
||||
.indicator_border_color(Some(cx.theme().colors().status_bar_background))
|
||||
@@ -380,7 +367,7 @@ impl Render for EditPredictionButton {
|
||||
let show_editor_predictions = self.editor_show_predictions;
|
||||
let user = self.user_store.read(cx).current_user();
|
||||
|
||||
let indicator_color = if missing_token {
|
||||
let indicator_color = if sweep_missing_token {
|
||||
Some(Color::Error)
|
||||
} else if enabled && (!show_editor_predictions || over_limit) {
|
||||
Some(if over_limit {
|
||||
@@ -392,7 +379,7 @@ impl Render for EditPredictionButton {
|
||||
None
|
||||
};
|
||||
|
||||
let icon_button = IconButton::new("zed-predict-pending-button", ep_icon)
|
||||
let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
|
||||
.shape(IconButtonShape::Square)
|
||||
.when_some(indicator_color, |this, color| {
|
||||
this.indicator(Indicator::dot().color(color))
|
||||
@@ -432,13 +419,13 @@ impl Render for EditPredictionButton {
|
||||
|
||||
let this = cx.weak_entity();
|
||||
|
||||
let mut popover_menu = PopoverMenu::new("edit-prediction")
|
||||
let mut popover_menu = PopoverMenu::new("zeta")
|
||||
.when(user.is_some(), |popover_menu| {
|
||||
let this = this.clone();
|
||||
|
||||
popover_menu.menu(move |window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.build_edit_prediction_context_menu(provider, window, cx)
|
||||
this.build_zeta_context_menu(provider, window, cx)
|
||||
})
|
||||
.ok()
|
||||
})
|
||||
@@ -498,7 +485,7 @@ impl EditPredictionButton {
|
||||
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
|
||||
.detach();
|
||||
|
||||
CodestralEditPredictionDelegate::ensure_api_key_loaded(client.http_client(), cx);
|
||||
CodestralCompletionProvider::ensure_api_key_loaded(client.http_client(), cx);
|
||||
|
||||
Self {
|
||||
editor_subscription: None,
|
||||
@@ -533,7 +520,7 @@ impl EditPredictionButton {
|
||||
}
|
||||
}
|
||||
|
||||
if CodestralEditPredictionDelegate::has_api_key(cx) {
|
||||
if CodestralCompletionProvider::has_api_key(cx) {
|
||||
providers.push(EditPredictionProvider::Codestral);
|
||||
}
|
||||
|
||||
@@ -543,12 +530,6 @@ impl EditPredictionButton {
|
||||
));
|
||||
}
|
||||
|
||||
if cx.has_flag::<MercuryFeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
));
|
||||
}
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
providers.push(EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
@@ -618,8 +599,8 @@ impl EditPredictionButton {
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
|
||||
let has_api_token = zeta::Zeta::try_global(cx)
|
||||
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
|
||||
|
||||
let should_open_modal = !has_api_token || is_current;
|
||||
|
||||
@@ -645,66 +626,7 @@ impl EditPredictionButton {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
ExternalProviderApiKeyModal::new(
|
||||
window,
|
||||
cx,
|
||||
|api_key, store, cx| {
|
||||
store
|
||||
.sweep_ai
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)
|
||||
});
|
||||
});
|
||||
};
|
||||
} else {
|
||||
set_completion_provider(fs.clone(), cx, provider);
|
||||
}
|
||||
});
|
||||
|
||||
menu.item(entry)
|
||||
}
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token());
|
||||
|
||||
let should_open_modal = !has_api_token || is_current;
|
||||
|
||||
let entry = if has_api_token {
|
||||
ContextMenuEntry::new("Mercury")
|
||||
.toggleable(IconPosition::Start, is_current)
|
||||
} else {
|
||||
ContextMenuEntry::new("Mercury")
|
||||
.icon(IconName::XCircle)
|
||||
.icon_color(Color::Error)
|
||||
.documentation_aside(
|
||||
DocumentationSide::Left,
|
||||
DocumentationEdge::Bottom,
|
||||
|_| {
|
||||
Label::new("Click to configure your Mercury API token")
|
||||
.into_any_element()
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
let entry = entry.handler(move |window, cx| {
|
||||
if should_open_modal {
|
||||
if let Some(workspace) = window.root::<Workspace>().flatten() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
ExternalProviderApiKeyModal::new(
|
||||
window,
|
||||
cx,
|
||||
|api_key, store, cx| {
|
||||
store
|
||||
.mercury
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)
|
||||
SweepApiKeyModal::new(window, cx)
|
||||
});
|
||||
});
|
||||
};
|
||||
@@ -1025,8 +947,8 @@ impl EditPredictionButton {
|
||||
)
|
||||
.context(editor_focus_handle)
|
||||
.when(
|
||||
cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
|
||||
|this| this.action("Rate Predictions", RatePredictions.boxed_clone()),
|
||||
cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>(),
|
||||
|this| this.action("Rate Completions", RateCompletions.boxed_clone()),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1094,7 +1016,7 @@ impl EditPredictionButton {
|
||||
})
|
||||
}
|
||||
|
||||
fn build_edit_prediction_context_menu(
|
||||
fn build_zeta_context_menu(
|
||||
&self,
|
||||
provider: EditPredictionProvider,
|
||||
window: &mut Window,
|
||||
@@ -1183,33 +1105,9 @@ impl EditPredictionButton {
|
||||
.separator();
|
||||
}
|
||||
|
||||
menu = self.build_language_settings_menu(menu, window, cx);
|
||||
let menu = self.build_language_settings_menu(menu, window, cx);
|
||||
let menu = self.add_provider_switching_section(menu, provider, cx);
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
let settings = all_language_settings(None, cx);
|
||||
let context_retrieval = settings.edit_predictions.use_context;
|
||||
menu = menu.separator().header("Context Retrieval").item(
|
||||
ContextMenuEntry::new("Enable Context Retrieval")
|
||||
.toggleable(IconPosition::Start, context_retrieval)
|
||||
.action(workspace::ToggleEditPrediction.boxed_clone())
|
||||
.handler({
|
||||
let fs = self.fs.clone();
|
||||
move |_, cx| {
|
||||
update_settings_file(fs.clone(), cx, move |settings, _| {
|
||||
settings
|
||||
.project
|
||||
.all_languages
|
||||
.features
|
||||
.get_or_insert_default()
|
||||
.experimental_edit_prediction_context_retrieval =
|
||||
Some(!context_retrieval)
|
||||
});
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
menu = self.add_provider_switching_section(menu, provider, cx);
|
||||
menu
|
||||
})
|
||||
}
|
||||
@@ -1,29 +1,23 @@
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use gpui::{
|
||||
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
|
||||
};
|
||||
use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
|
||||
use ui_input::InputField;
|
||||
use workspace::ModalView;
|
||||
use zeta::Zeta;
|
||||
|
||||
pub struct ExternalProviderApiKeyModal {
|
||||
pub struct SweepApiKeyModal {
|
||||
api_key_input: Entity<InputField>,
|
||||
focus_handle: FocusHandle,
|
||||
on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
|
||||
}
|
||||
|
||||
impl ExternalProviderApiKeyModal {
|
||||
pub fn new(
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
|
||||
) -> Self {
|
||||
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
|
||||
impl SweepApiKeyModal {
|
||||
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
|
||||
|
||||
Self {
|
||||
api_key_input,
|
||||
focus_handle: cx.focus_handle(),
|
||||
on_confirm: Box::new(on_confirm),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,35 +29,39 @@ impl ExternalProviderApiKeyModal {
|
||||
let api_key = self.api_key_input.read(cx).text(cx);
|
||||
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
|
||||
|
||||
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
|
||||
ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
|
||||
if let Some(zeta) = Zeta::try_global(cx) {
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.sweep_ai
|
||||
.set_api_token(api_key, cx)
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
}
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
|
||||
impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
|
||||
|
||||
impl ModalView for ExternalProviderApiKeyModal {}
|
||||
impl ModalView for SweepApiKeyModal {}
|
||||
|
||||
impl Focusable for ExternalProviderApiKeyModal {
|
||||
impl Focusable for SweepApiKeyModal {
|
||||
fn focus_handle(&self, _cx: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ExternalProviderApiKeyModal {
|
||||
impl Render for SweepApiKeyModal {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.key_context("ExternalApiKeyModal")
|
||||
.key_context("SweepApiKeyModal")
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::confirm))
|
||||
.elevation_2(cx)
|
||||
.w(px(400.))
|
||||
.p_4()
|
||||
.gap_3()
|
||||
.child(Headline::new("API Token").size(HeadlineSize::Small))
|
||||
.child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
|
||||
.child(self.api_key_input.clone())
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -1,89 +0,0 @@
|
||||
use std::path::Path;
|
||||
|
||||
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
|
||||
|
||||
#[derive(Debug, Clone, Default, clap::ValueEnum)]
|
||||
pub enum ContextType {
|
||||
#[default]
|
||||
CurrentFile,
|
||||
}
|
||||
|
||||
const MAX_CONTEXT_SIZE: usize = 32768;
|
||||
|
||||
pub fn collect_context(
|
||||
context_type: &ContextType,
|
||||
worktree_dir: &Path,
|
||||
cursor: SourceLocation,
|
||||
) -> String {
|
||||
let context = match context_type {
|
||||
ContextType::CurrentFile => {
|
||||
let file_path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
|
||||
|
||||
let context = add_special_tags(&context, worktree_dir, cursor);
|
||||
context
|
||||
}
|
||||
};
|
||||
|
||||
let region_end_offset = context.find(TeacherModel::REGION_END);
|
||||
|
||||
if context.len() <= MAX_CONTEXT_SIZE {
|
||||
return context;
|
||||
}
|
||||
|
||||
if let Some(region_end_offset) = region_end_offset
|
||||
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
|
||||
{
|
||||
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
|
||||
format!(
|
||||
"[...{} bytes truncated]\n{}\n",
|
||||
to_truncate,
|
||||
&context[to_truncate..]
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}\n[...{} bytes truncated]\n",
|
||||
&context[..MAX_CONTEXT_SIZE],
|
||||
context.len() - MAX_CONTEXT_SIZE
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Add <|editable_region_start/end|> tags
|
||||
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
|
||||
let path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let file = std::fs::read_to_string(&path).unwrap_or_default();
|
||||
let lines = file.lines().collect::<Vec<_>>();
|
||||
let cursor_row = cursor.point.row as usize;
|
||||
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
|
||||
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
|
||||
|
||||
let snippet = lines[start_line..end_line].join("\n");
|
||||
|
||||
if context.contains(&snippet) {
|
||||
let mut cursor_line = lines[cursor_row].to_string();
|
||||
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
|
||||
|
||||
let mut snippet_with_tags_lines = vec![];
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_START);
|
||||
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
|
||||
snippet_with_tags_lines.push(&cursor_line);
|
||||
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_END);
|
||||
let snippet_with_tags = snippet_with_tags_lines.join("\n");
|
||||
|
||||
context.replace(&snippet, &snippet_with_tags)
|
||||
} else {
|
||||
log::warn!(
|
||||
"Can't find area around the cursor in the context; proceeding without special tags"
|
||||
);
|
||||
context.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn strip_special_tags(context: &str) -> String {
|
||||
context
|
||||
.replace(TeacherModel::REGION_START, "")
|
||||
.replace(TeacherModel::REGION_END, "")
|
||||
.replace(TeacherModel::USER_CURSOR, "")
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
DistillArguments,
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::ContextType,
|
||||
llm_client::LlmClient,
|
||||
teacher::{TeacherModel, TeacherOutput},
|
||||
},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use reqwest_client::ReqwestClient;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SplitCommit {
|
||||
repo_url: String,
|
||||
commit_sha: String,
|
||||
edit_history: String,
|
||||
expected_patch: String,
|
||||
cursor_position: String,
|
||||
}
|
||||
|
||||
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
|
||||
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
|
||||
.expect("Failed to read split commit dataset")
|
||||
.lines()
|
||||
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
|
||||
.collect();
|
||||
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
|
||||
let llm_client = if let Some(cache_path) = arguments.batch {
|
||||
LlmClient::batch(&cache_path, http_client)?
|
||||
} else {
|
||||
LlmClient::plain(http_client)?
|
||||
};
|
||||
|
||||
let mut teacher = TeacherModel::new(
|
||||
"claude-sonnet-4-5".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
llm_client,
|
||||
);
|
||||
|
||||
let mut num_marked_for_batching = 0;
|
||||
|
||||
for commit in split_commits {
|
||||
if let Some(distilled) = distill_one(&mut teacher, commit).await? {
|
||||
println!("{}", serde_json::to_string(&distilled)?);
|
||||
} else {
|
||||
if num_marked_for_batching == 0 {
|
||||
log::warn!("Marked for batching");
|
||||
}
|
||||
num_marked_for_batching += 1;
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"{} requests are marked for batching",
|
||||
num_marked_for_batching
|
||||
);
|
||||
let llm_client = teacher.client;
|
||||
llm_client.sync_batches().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn distill_one(
|
||||
teacher: &mut TeacherModel,
|
||||
commit: SplitCommit,
|
||||
) -> Result<Option<TeacherOutput>> {
|
||||
let cursor: SourceLocation = commit
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let path = cursor.path.to_rel_path_buf();
|
||||
|
||||
let example = Example {
|
||||
repository_url: commit.repo_url,
|
||||
revision: commit.commit_sha,
|
||||
uncommitted_diff: commit.edit_history.clone(),
|
||||
cursor_path: path.as_std_path().to_path_buf(),
|
||||
cursor_position: commit.cursor_position,
|
||||
edit_history: commit.edit_history, // todo: trim
|
||||
expected_patch: commit.expected_patch,
|
||||
};
|
||||
|
||||
let prediction = teacher.predict(example).await;
|
||||
|
||||
prediction
|
||||
}
|
||||
@@ -1,417 +0,0 @@
|
||||
use anthropic::{
|
||||
ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
|
||||
Response as AnthropicResponse, Role, non_streaming_completion,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use http_client::HttpClient;
|
||||
use indoc::indoc;
|
||||
use sqlez::bindable::Bind;
|
||||
use sqlez::bindable::StaticColumnCount;
|
||||
use sqlez_macros::sql;
|
||||
use std::hash::Hash;
|
||||
use std::hash::Hasher;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PlainLlmClient {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl PlainLlmClient {
|
||||
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
Ok(Self {
|
||||
http_client,
|
||||
api_key,
|
||||
})
|
||||
}
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<AnthropicResponse> {
|
||||
let request = AnthropicRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
thinking: None,
|
||||
tool_choice: None,
|
||||
system: None,
|
||||
metadata: None,
|
||||
stop_sequences: Vec::new(),
|
||||
temperature: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
};
|
||||
|
||||
let response = non_streaming_completion(
|
||||
self.http_client.as_ref(),
|
||||
ANTHROPIC_API_URL,
|
||||
&self.api_key,
|
||||
request,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BatchingLlmClient {
|
||||
connection: sqlez::connection::Connection,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
struct CacheRow {
|
||||
request_hash: String,
|
||||
request: Option<String>,
|
||||
response: Option<String>,
|
||||
batch_id: Option<String>,
|
||||
}
|
||||
|
||||
impl StaticColumnCount for CacheRow {
|
||||
fn column_count() -> usize {
|
||||
4
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind for CacheRow {
|
||||
fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
|
||||
let next_index = statement.bind(&self.request_hash, start_index)?;
|
||||
let next_index = statement.bind(&self.request, next_index)?;
|
||||
let next_index = statement.bind(&self.response, next_index)?;
|
||||
let next_index = statement.bind(&self.batch_id, next_index)?;
|
||||
Ok(next_index)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct SerializableRequest {
|
||||
model: String,
|
||||
max_tokens: u64,
|
||||
messages: Vec<SerializableMessage>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct SerializableMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl BatchingLlmClient {
|
||||
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path);
|
||||
let mut statement = sqlez::statement::Statement::prepare(
|
||||
&connection,
|
||||
indoc! {"
|
||||
CREATE TABLE IF NOT EXISTS cache (
|
||||
request_hash TEXT PRIMARY KEY,
|
||||
request TEXT,
|
||||
response TEXT,
|
||||
batch_id TEXT
|
||||
);
|
||||
"},
|
||||
)?;
|
||||
statement.exec()?;
|
||||
drop(statement);
|
||||
|
||||
Ok(Self {
|
||||
connection,
|
||||
http_client,
|
||||
api_key,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn lookup(
|
||||
&self,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: &[Message],
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
let request_hash_str = Self::request_hash(model, max_tokens, messages);
|
||||
let response: Vec<String> = self.connection.select_bound(
|
||||
&sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
|
||||
)?(request_hash_str.as_str())?;
|
||||
Ok(response
|
||||
.into_iter()
|
||||
.next()
|
||||
.and_then(|text| serde_json::from_str(&text).ok()))
|
||||
}
|
||||
|
||||
pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
|
||||
let request_hash = Self::request_hash(model, max_tokens, messages);
|
||||
|
||||
let serializable_messages: Vec<SerializableMessage> = messages
|
||||
.iter()
|
||||
.map(|msg| SerializableMessage {
|
||||
role: match msg.role {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Assistant => "assistant".to_string(),
|
||||
},
|
||||
content: message_content_to_string(&msg.content),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let serializable_request = SerializableRequest {
|
||||
model: model.to_string(),
|
||||
max_tokens,
|
||||
messages: serializable_messages,
|
||||
};
|
||||
|
||||
let request = Some(serde_json::to_string(&serializable_request)?);
|
||||
let cache_row = CacheRow {
|
||||
request_hash,
|
||||
request,
|
||||
response: None,
|
||||
batch_id: None,
|
||||
};
|
||||
self.connection.exec_bound(sql!(
|
||||
INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
|
||||
cache_row,
|
||||
)
|
||||
}
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
let response = self.lookup(&model, max_tokens, &messages)?;
|
||||
if let Some(response) = response {
|
||||
return Ok(Some(response));
|
||||
}
|
||||
|
||||
self.mark_for_batch(&model, max_tokens, &messages)?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Uploads pending requests as a new batch; downloads finished batches if any.
|
||||
async fn sync_batches(&self) -> Result<()> {
|
||||
self.upload_pending_requests().await?;
|
||||
self.download_finished_batches().await
|
||||
}
|
||||
|
||||
async fn download_finished_batches(&self) -> Result<()> {
|
||||
let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
|
||||
let batch_ids: Vec<String> = self.connection.select(q)?()?;
|
||||
|
||||
for batch_id in batch_ids {
|
||||
let batch_status = anthropic::batches::retrieve_batch(
|
||||
self.http_client.as_ref(),
|
||||
ANTHROPIC_API_URL,
|
||||
&self.api_key,
|
||||
&batch_id,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
log::info!(
|
||||
"Batch {} status: {}",
|
||||
batch_id,
|
||||
batch_status.processing_status
|
||||
);
|
||||
|
||||
if batch_status.processing_status == "ended" {
|
||||
let results = anthropic::batches::retrieve_batch_results(
|
||||
self.http_client.as_ref(),
|
||||
ANTHROPIC_API_URL,
|
||||
&self.api_key,
|
||||
&batch_id,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
let mut success_count = 0;
|
||||
for result in results {
|
||||
let request_hash = result
|
||||
.custom_id
|
||||
.strip_prefix("req_hash_")
|
||||
.unwrap_or(&result.custom_id)
|
||||
.to_string();
|
||||
|
||||
match result.result {
|
||||
anthropic::batches::BatchResult::Succeeded { message } => {
|
||||
let response_json = serde_json::to_string(&message)?;
|
||||
let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
|
||||
self.connection.exec_bound(q)?((response_json, request_hash))?;
|
||||
success_count += 1;
|
||||
}
|
||||
anthropic::batches::BatchResult::Errored { error } => {
|
||||
log::error!("Batch request {} failed: {:?}", request_hash, error);
|
||||
}
|
||||
anthropic::batches::BatchResult::Canceled => {
|
||||
log::warn!("Batch request {} was canceled", request_hash);
|
||||
}
|
||||
anthropic::batches::BatchResult::Expired => {
|
||||
log::warn!("Batch request {} expired", request_hash);
|
||||
}
|
||||
}
|
||||
}
|
||||
log::info!("Uploaded {} successful requests", success_count);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn upload_pending_requests(&self) -> Result<String> {
|
||||
let q = sql!(
|
||||
SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
|
||||
);
|
||||
|
||||
let rows: Vec<(String, String)> = self.connection.select(q)?()?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let batch_requests = rows
|
||||
.iter()
|
||||
.map(|(hash, request_str)| {
|
||||
let serializable_request: SerializableRequest =
|
||||
serde_json::from_str(&request_str).unwrap();
|
||||
|
||||
let messages: Vec<Message> = serializable_request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| Message {
|
||||
role: match msg.role.as_str() {
|
||||
"user" => Role::User,
|
||||
"assistant" => Role::Assistant,
|
||||
_ => Role::User,
|
||||
},
|
||||
content: vec![RequestContent::Text {
|
||||
text: msg.content,
|
||||
cache_control: None,
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
|
||||
let params = AnthropicRequest {
|
||||
model: serializable_request.model,
|
||||
max_tokens: serializable_request.max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
thinking: None,
|
||||
tool_choice: None,
|
||||
system: None,
|
||||
metadata: None,
|
||||
stop_sequences: Vec::new(),
|
||||
temperature: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
};
|
||||
|
||||
let custom_id = format!("req_hash_{}", hash);
|
||||
anthropic::batches::BatchRequest { custom_id, params }
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let batch_len = batch_requests.len();
|
||||
let batch = anthropic::batches::create_batch(
|
||||
self.http_client.as_ref(),
|
||||
ANTHROPIC_API_URL,
|
||||
&self.api_key,
|
||||
anthropic::batches::CreateBatchRequest {
|
||||
requests: batch_requests,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{:?}", e))?;
|
||||
|
||||
let q = sql!(
|
||||
UPDATE cache SET batch_id = ? WHERE batch_id is NULL
|
||||
);
|
||||
self.connection.exec_bound(q)?(batch.id.as_str())?;
|
||||
|
||||
log::info!("Uploaded batch with {} requests", batch_len);
|
||||
|
||||
Ok(batch.id)
|
||||
}
|
||||
|
||||
fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
|
||||
let mut hasher = std::hash::DefaultHasher::new();
|
||||
model.hash(&mut hasher);
|
||||
max_tokens.hash(&mut hasher);
|
||||
for msg in messages {
|
||||
message_content_to_string(&msg.content).hash(&mut hasher);
|
||||
}
|
||||
let request_hash = hasher.finish();
|
||||
format!("{request_hash:016x}")
|
||||
}
|
||||
}
|
||||
|
||||
fn message_content_to_string(content: &[RequestContent]) -> String {
|
||||
content
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
RequestContent::Text { text, .. } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub enum LlmClient {
|
||||
// No batching
|
||||
Plain(PlainLlmClient),
|
||||
Batch(BatchingLlmClient),
|
||||
Dummy,
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
|
||||
}
|
||||
|
||||
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(
|
||||
cache_path,
|
||||
http_client,
|
||||
)?))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn dummy() -> Self {
|
||||
Self::Dummy
|
||||
}
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
match self {
|
||||
LlmClient::Plain(plain_llm_client) => plain_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
.map(Some),
|
||||
LlmClient::Batch(batching_llm_client) => {
|
||||
batching_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
}
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync_batches(&self) -> Result<()> {
|
||||
match self {
|
||||
LlmClient::Plain(_) => Ok(()),
|
||||
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
pub mod context;
|
||||
pub mod distill;
|
||||
pub mod llm_client;
|
||||
pub mod teacher;
|
||||
@@ -1,48 +0,0 @@
|
||||
# Instructions
|
||||
|
||||
You are a code completion assistant helping a programmer finish their work. Your task is to:
|
||||
|
||||
1. Analyze the edit history to understand what the programmer is trying to achieve
|
||||
2. Identify any incomplete refactoring or changes that need to be finished
|
||||
3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections)
|
||||
4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
|
||||
|
||||
Focus on:
|
||||
- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
|
||||
- Completing any partially-applied changes across the codebase
|
||||
- Ensuring consistency with the programming style and patterns already established
|
||||
- Making edits that maintain or improve code quality
|
||||
- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
|
||||
- Don't write a lot of code if you're not sure what to do
|
||||
|
||||
Rules:
|
||||
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
|
||||
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
|
||||
|
||||
Input format:
|
||||
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
|
||||
- Never modify the context code.
|
||||
- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region.
|
||||
- The cursor position is marked with <|user_cursor|>.
|
||||
|
||||
Output format:
|
||||
- Return the entire editable region, applying any edits you make.
|
||||
- Remove the <|user_cursor|> marker.
|
||||
- Wrap the edited code in a block of exactly five backticks.
|
||||
|
||||
Output example:
|
||||
`````
|
||||
// `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass
|
||||
if let Some(socket) = &args.askpass {{
|
||||
askpass::main(socket);
|
||||
return Ok(());
|
||||
}}
|
||||
`````
|
||||
|
||||
## User Edits History
|
||||
|
||||
{{edit_history}}
|
||||
|
||||
## Code Context
|
||||
|
||||
{{context}}
|
||||
@@ -1,266 +0,0 @@
|
||||
use crate::{
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::{ContextType, collect_context, strip_special_tags},
|
||||
llm_client::LlmClient,
|
||||
},
|
||||
};
|
||||
use anthropic::{Message, RequestContent, ResponseContent, Role};
|
||||
use anyhow::Result;
|
||||
|
||||
pub struct TeacherModel {
|
||||
pub llm_name: String,
|
||||
pub context: ContextType,
|
||||
pub client: LlmClient,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct TeacherOutput {
|
||||
parsed_output: String,
|
||||
prompt: String,
|
||||
raw_llm_response: String,
|
||||
context: String,
|
||||
diff: String,
|
||||
}
|
||||
|
||||
impl TeacherModel {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
|
||||
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
|
||||
|
||||
/// Number of lines to include before the cursor position
|
||||
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Number of lines to include after the cursor position
|
||||
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
|
||||
TeacherModel {
|
||||
llm_name,
|
||||
context,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
|
||||
let name = input.unique_name();
|
||||
let worktree_dir = input.setup_worktree(name).await?;
|
||||
let cursor: SourceLocation = input
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
|
||||
let edit_history = Self::format_edit_history(&input.edit_history);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history);
|
||||
|
||||
let messages = vec![Message {
|
||||
role: Role::User,
|
||||
content: vec![RequestContent::Text {
|
||||
text: prompt.clone(),
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let Some(response) = self
|
||||
.client
|
||||
.generate(self.llm_name.clone(), 16384, messages)
|
||||
.await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let response_text = response
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
ResponseContent::Text { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let parsed_output = self.parse_response(&response_text);
|
||||
|
||||
let original_editable_region = Self::extract_editable_region(&context);
|
||||
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
|
||||
let context_after_edit = strip_special_tags(&context_after_edit);
|
||||
let context_before_edit = strip_special_tags(&context);
|
||||
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
|
||||
|
||||
// zeta distill --batch batch_results.txt
|
||||
// zeta distill
|
||||
// 1. Run `zeta distill <2000 examples <- all examples>` for the first time
|
||||
// - store LLM requests in a batch, don't actual send the request
|
||||
// - send the batch (2000 requests) after all inputs are processed
|
||||
// 2. `zeta send-batches`
|
||||
// - upload the batch to Anthropic
|
||||
|
||||
// https://platform.claude.com/docs/en/build-with-claude/batch-processing
|
||||
// https://crates.io/crates/anthropic-sdk-rust
|
||||
|
||||
// - poll for results
|
||||
// - when ready, store results in cache (a database)
|
||||
// 3. `zeta distill` again
|
||||
// - use the cached results this time
|
||||
|
||||
Ok(Some(TeacherOutput {
|
||||
parsed_output,
|
||||
prompt,
|
||||
raw_llm_response: response_text,
|
||||
context,
|
||||
diff,
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_response(&self, content: &str) -> String {
|
||||
let codeblock = Self::extract_last_codeblock(content);
|
||||
let editable_region = Self::extract_editable_region(&codeblock);
|
||||
|
||||
editable_region
|
||||
}
|
||||
|
||||
/// Extract content from the last code-fenced block if any, or else return content as is
|
||||
fn extract_last_codeblock(text: &str) -> String {
|
||||
let mut last_block = None;
|
||||
let mut search_start = 0;
|
||||
|
||||
while let Some(start) = text[search_start..].find("```") {
|
||||
let start = start + search_start;
|
||||
let bytes = text.as_bytes();
|
||||
let mut backtick_end = start;
|
||||
|
||||
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
|
||||
backtick_end += 1;
|
||||
}
|
||||
|
||||
let backtick_count = backtick_end - start;
|
||||
let closing_backticks = "`".repeat(backtick_count);
|
||||
|
||||
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
|
||||
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
|
||||
last_block = Some(code_block.to_string());
|
||||
search_start = backtick_end + end_pos + backtick_count;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
last_block.unwrap_or_else(|| text.to_string())
|
||||
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::REGION_START)
|
||||
.map_or(0, |pos| pos + Self::REGION_START.len());
|
||||
let end = text.find(Self::REGION_END).unwrap_or(text.len());
|
||||
|
||||
text[start..end].to_string()
|
||||
}
|
||||
|
||||
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
fn is_content_line(s: &str) -> bool {
|
||||
s.starts_with("-")
|
||||
|| s.starts_with("+")
|
||||
|| s.starts_with(" ")
|
||||
|| s.starts_with("---")
|
||||
|| s.starts_with("+++")
|
||||
|| s.starts_with("@@")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let teacher = TeacherModel::new(
|
||||
"test".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
LlmClient::dummy(),
|
||||
);
|
||||
let response = "This is a test response.";
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, response.to_string());
|
||||
|
||||
let response = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
`````
|
||||
actual response
|
||||
`````
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, "actual response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_last_code_block() {
|
||||
let text = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
```
|
||||
first block
|
||||
```
|
||||
|
||||
`````
|
||||
last block
|
||||
`````
|
||||
"};
|
||||
let last_block = TeacherModel::extract_last_codeblock(text);
|
||||
assert_eq!(last_block, "last block");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_editable_region() {
|
||||
let teacher = TeacherModel::new(
|
||||
"test".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
LlmClient::dummy(),
|
||||
);
|
||||
let response = indoc::indoc! {"
|
||||
some lines
|
||||
are
|
||||
here
|
||||
<|editable_region_start|>
|
||||
one
|
||||
two three
|
||||
|
||||
<|editable_region_end|>
|
||||
more
|
||||
lines here
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
indoc::indoc! {"
|
||||
one
|
||||
two three
|
||||
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -12,32 +12,41 @@ workspace = true
|
||||
path = "src/edit_prediction_context.rs"
|
||||
|
||||
[dependencies]
|
||||
parking_lot.workspace = true
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
hashbrown.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
lsp.workspace = true
|
||||
project.workspace = true
|
||||
log.workspace = true
|
||||
ordered-float.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
slotmap.workspace = true
|
||||
strum.workspace = true
|
||||
text.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
indoc.workspace = true
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
project = {workspace= true, features = ["test-support"]}
|
||||
serde_json.workspace = true
|
||||
settings = {workspace= true, features = ["test-support"]}
|
||||
text = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-c.workspace = true
|
||||
tree-sitter-cpp.workspace = true
|
||||
tree-sitter-go.workspace = true
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
use crate::RelatedExcerpt;
|
||||
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::ops::Range;
|
||||
|
||||
#[cfg(not(test))]
|
||||
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
|
||||
#[cfg(test)]
|
||||
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 24;
|
||||
|
||||
pub fn assemble_excerpts(
|
||||
buffer: &BufferSnapshot,
|
||||
mut input_ranges: Vec<Range<Point>>,
|
||||
) -> Vec<RelatedExcerpt> {
|
||||
merge_ranges(&mut input_ranges);
|
||||
|
||||
let mut outline_ranges = Vec::new();
|
||||
let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
|
||||
let mut outline_ix = 0;
|
||||
for input_range in &mut input_ranges {
|
||||
*input_range = clip_range_to_lines(input_range, false, buffer);
|
||||
|
||||
while let Some(outline_item) = outline_items.get(outline_ix) {
|
||||
let item_range = clip_range_to_lines(&outline_item.range, false, buffer);
|
||||
|
||||
if item_range.start > input_range.start {
|
||||
break;
|
||||
}
|
||||
|
||||
if item_range.end > input_range.start {
|
||||
let body_range = outline_item
|
||||
.body_range(buffer)
|
||||
.map(|body| clip_range_to_lines(&body, true, buffer))
|
||||
.filter(|body_range| {
|
||||
body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE
|
||||
});
|
||||
|
||||
add_outline_item(
|
||||
item_range.clone(),
|
||||
body_range.clone(),
|
||||
buffer,
|
||||
&mut outline_ranges,
|
||||
);
|
||||
|
||||
if let Some(body_range) = body_range
|
||||
&& input_range.start < body_range.start
|
||||
{
|
||||
let mut child_outline_ix = outline_ix + 1;
|
||||
while let Some(next_outline_item) = outline_items.get(child_outline_ix) {
|
||||
if next_outline_item.range.end > body_range.end {
|
||||
break;
|
||||
}
|
||||
if next_outline_item.depth == outline_item.depth + 1 {
|
||||
let next_item_range =
|
||||
clip_range_to_lines(&next_outline_item.range, false, buffer);
|
||||
|
||||
add_outline_item(
|
||||
next_item_range,
|
||||
next_outline_item
|
||||
.body_range(buffer)
|
||||
.map(|body| clip_range_to_lines(&body, true, buffer)),
|
||||
buffer,
|
||||
&mut outline_ranges,
|
||||
);
|
||||
}
|
||||
child_outline_ix += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outline_ix += 1;
|
||||
}
|
||||
}
|
||||
|
||||
input_ranges.extend_from_slice(&outline_ranges);
|
||||
merge_ranges(&mut input_ranges);
|
||||
|
||||
input_ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
let offset_range = range.to_offset(buffer);
|
||||
RelatedExcerpt {
|
||||
point_range: range,
|
||||
anchor_range: buffer.anchor_before(offset_range.start)
|
||||
..buffer.anchor_after(offset_range.end),
|
||||
text: buffer.as_rope().slice(offset_range),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn clip_range_to_lines(
|
||||
range: &Range<Point>,
|
||||
inward: bool,
|
||||
buffer: &BufferSnapshot,
|
||||
) -> Range<Point> {
|
||||
let mut range = range.clone();
|
||||
if inward {
|
||||
if range.start.column > 0 {
|
||||
range.start.column = buffer.line_len(range.start.row);
|
||||
}
|
||||
range.end.column = 0;
|
||||
} else {
|
||||
range.start.column = 0;
|
||||
if range.end.column > 0 {
|
||||
range.end.column = buffer.line_len(range.end.row);
|
||||
}
|
||||
}
|
||||
range
|
||||
}
|
||||
|
||||
fn add_outline_item(
|
||||
mut item_range: Range<Point>,
|
||||
body_range: Option<Range<Point>>,
|
||||
buffer: &BufferSnapshot,
|
||||
outline_ranges: &mut Vec<Range<Point>>,
|
||||
) {
|
||||
if let Some(mut body_range) = body_range {
|
||||
if body_range.start.column > 0 {
|
||||
body_range.start.column = buffer.line_len(body_range.start.row);
|
||||
}
|
||||
body_range.end.column = 0;
|
||||
|
||||
let head_range = item_range.start..body_range.start;
|
||||
if head_range.start < head_range.end {
|
||||
outline_ranges.push(head_range);
|
||||
}
|
||||
|
||||
let tail_range = body_range.end..item_range.end;
|
||||
if tail_range.start < tail_range.end {
|
||||
outline_ranges.push(tail_range);
|
||||
}
|
||||
} else {
|
||||
item_range.start.column = 0;
|
||||
item_range.end.column = buffer.line_len(item_range.end.row);
|
||||
outline_ranges.push(item_range);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge_ranges(ranges: &mut Vec<Range<Point>>) {
|
||||
ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
|
||||
|
||||
let mut index = 1;
|
||||
while index < ranges.len() {
|
||||
let mut prev_range_end = ranges[index - 1].end;
|
||||
if prev_range_end.column > 0 {
|
||||
prev_range_end += Point::new(1, 0);
|
||||
}
|
||||
|
||||
if (prev_range_end + Point::new(1, 0))
|
||||
.cmp(&ranges[index].start)
|
||||
.is_ge()
|
||||
{
|
||||
let removed = ranges.remove(index);
|
||||
if removed.end.cmp(&ranges[index - 1].end).is_gt() {
|
||||
ranges[index - 1].end = removed.end;
|
||||
}
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
350
crates/edit_prediction_context/src/declaration.rs
Normal file
350
crates/edit_prediction_context/src/declaration.rs
Normal file
@@ -0,0 +1,350 @@
|
||||
use cloud_llm_client::predict_edits_v3::{self, Line};
|
||||
use language::{Language, LanguageId};
|
||||
use project::ProjectEntryId;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, path::Path};
|
||||
use text::{Bias, BufferId, Rope};
|
||||
use util::paths::{path_ends_with, strip_path_suffix};
|
||||
use util::rel_path::RelPath;
|
||||
|
||||
use crate::outline::OutlineDeclaration;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct Identifier {
|
||||
pub name: Arc<str>,
|
||||
pub language_id: LanguageId,
|
||||
}
|
||||
|
||||
slotmap::new_key_type! {
|
||||
pub struct DeclarationId;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Declaration {
|
||||
File {
|
||||
project_entry_id: ProjectEntryId,
|
||||
declaration: FileDeclaration,
|
||||
cached_path: CachedDeclarationPath,
|
||||
},
|
||||
Buffer {
|
||||
project_entry_id: ProjectEntryId,
|
||||
buffer_id: BufferId,
|
||||
rope: Rope,
|
||||
declaration: BufferDeclaration,
|
||||
cached_path: CachedDeclarationPath,
|
||||
},
|
||||
}
|
||||
|
||||
const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
|
||||
|
||||
impl Declaration {
|
||||
pub fn identifier(&self) -> &Identifier {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => &declaration.identifier,
|
||||
Declaration::Buffer { declaration, .. } => &declaration.identifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parent(&self) -> Option<DeclarationId> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.parent,
|
||||
Declaration::Buffer { declaration, .. } => declaration.parent,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
|
||||
match self {
|
||||
Declaration::File { .. } => None,
|
||||
Declaration::Buffer { declaration, .. } => Some(declaration),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_file(&self) -> Option<&FileDeclaration> {
|
||||
match self {
|
||||
Declaration::Buffer { .. } => None,
|
||||
Declaration::File { declaration, .. } => Some(declaration),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project_entry_id(&self) -> ProjectEntryId {
|
||||
match self {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => *project_entry_id,
|
||||
Declaration::Buffer {
|
||||
project_entry_id, ..
|
||||
} => *project_entry_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cached_path(&self) -> &CachedDeclarationPath {
|
||||
match self {
|
||||
Declaration::File { cached_path, .. } => cached_path,
|
||||
Declaration::Buffer { cached_path, .. } => cached_path,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_range(&self) -> Range<usize> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.item_range.clone(),
|
||||
Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_line_range(&self) -> Range<Line> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.item_line_range.clone(),
|
||||
Declaration::Buffer {
|
||||
declaration, rope, ..
|
||||
} => {
|
||||
Line(rope.offset_to_point(declaration.item_range.start).row)
|
||||
..Line(rope.offset_to_point(declaration.item_range.end).row)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text.as_ref().into(),
|
||||
declaration.text_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.item_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.item_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text[self.signature_range_in_item_text()].into(),
|
||||
declaration.signature_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.signature_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.signature_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_range(&self) -> Range<usize> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.signature_range.clone(),
|
||||
Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_line_range(&self) -> Range<Line> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.signature_line_range.clone(),
|
||||
Declaration::Buffer {
|
||||
declaration, rope, ..
|
||||
} => {
|
||||
Line(rope.offset_to_point(declaration.signature_range.start).row)
|
||||
..Line(rope.offset_to_point(declaration.signature_range.end).row)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_range_in_item_text(&self) -> Range<usize> {
|
||||
let signature_range = self.signature_range();
|
||||
let item_range = self.item_range();
|
||||
signature_range.start.saturating_sub(item_range.start)
|
||||
..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len())
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_range_to_line_boundaries_and_truncate(
|
||||
range: &Range<usize>,
|
||||
limit: usize,
|
||||
rope: &Rope,
|
||||
) -> (Range<usize>, Range<predict_edits_v3::Line>, bool) {
|
||||
let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
|
||||
point_range.start.column = 0;
|
||||
point_range.end.row += 1;
|
||||
point_range.end.column = 0;
|
||||
|
||||
let mut item_range =
|
||||
rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
|
||||
let is_truncated = item_range.len() > limit;
|
||||
if is_truncated {
|
||||
item_range.end = item_range.start + limit;
|
||||
}
|
||||
item_range.end = rope.clip_offset(item_range.end, Bias::Left);
|
||||
|
||||
let line_range =
|
||||
predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row);
|
||||
(item_range, line_range, is_truncated)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
/// offset range of the declaration in the file, expanded to line boundaries and truncated
|
||||
pub item_range: Range<usize>,
|
||||
/// line range of the declaration in the file, potentially truncated
|
||||
pub item_line_range: Range<predict_edits_v3::Line>,
|
||||
/// text of `item_range`
|
||||
pub text: Arc<str>,
|
||||
/// whether `text` was truncated
|
||||
pub text_is_truncated: bool,
|
||||
/// offset range of the signature in the file, expanded to line boundaries and truncated
|
||||
pub signature_range: Range<usize>,
|
||||
/// line range of the signature in the file, truncated
|
||||
pub signature_line_range: Range<Line>,
|
||||
/// whether `signature` was truncated
|
||||
pub signature_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl FileDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
|
||||
let (item_range_in_file, item_line_range_in_file, text_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
|
||||
let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.signature_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
|
||||
if signature_range_in_file.start < item_range_in_file.start {
|
||||
signature_range_in_file.start = item_range_in_file.start;
|
||||
signature_is_truncated = true;
|
||||
}
|
||||
if signature_range_in_file.end > item_range_in_file.end {
|
||||
signature_range_in_file.end = item_range_in_file.end;
|
||||
signature_is_truncated = true;
|
||||
}
|
||||
|
||||
FileDeclaration {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
signature_range: signature_range_in_file,
|
||||
signature_line_range,
|
||||
signature_is_truncated,
|
||||
text: rope
|
||||
.chunks_in_range(item_range_in_file.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
text_is_truncated,
|
||||
item_range: item_range_in_file,
|
||||
item_line_range: item_line_range_in_file,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BufferDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
pub item_range: Range<usize>,
|
||||
pub item_range_is_truncated: bool,
|
||||
pub signature_range: Range<usize>,
|
||||
pub signature_range_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl BufferDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
|
||||
let (item_range, _item_line_range, item_range_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
let (signature_range, _signature_line_range, signature_range_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.signature_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
Self {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
item_range,
|
||||
item_range_is_truncated,
|
||||
signature_range,
|
||||
signature_range_is_truncated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CachedDeclarationPath {
|
||||
pub worktree_abs_path: Arc<Path>,
|
||||
pub rel_path: Arc<RelPath>,
|
||||
/// The relative path of the file, possibly stripped according to `import_path_strip_regex`.
|
||||
pub rel_path_after_regex_stripping: Arc<RelPath>,
|
||||
}
|
||||
|
||||
impl CachedDeclarationPath {
|
||||
pub fn new(
|
||||
worktree_abs_path: Arc<Path>,
|
||||
path: &Arc<RelPath>,
|
||||
language: Option<&Arc<Language>>,
|
||||
) -> Self {
|
||||
let rel_path = path.clone();
|
||||
let rel_path_after_regex_stripping = if let Some(language) = language
|
||||
&& let Some(strip_regex) = language.config().import_path_strip_regex.as_ref()
|
||||
&& let Ok(stripped) = RelPath::unix(&Path::new(
|
||||
strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(),
|
||||
)) {
|
||||
Arc::from(stripped)
|
||||
} else {
|
||||
rel_path.clone()
|
||||
};
|
||||
CachedDeclarationPath {
|
||||
worktree_abs_path,
|
||||
rel_path,
|
||||
rel_path_after_regex_stripping,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self {
|
||||
let rel_path: Arc<RelPath> = util::rel_path::rel_path(rel_path).into();
|
||||
CachedDeclarationPath {
|
||||
worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(),
|
||||
rel_path_after_regex_stripping: rel_path.clone(),
|
||||
rel_path,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ends_with_posix_path(&self, path: &Path) -> bool {
|
||||
if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() {
|
||||
path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path)
|
||||
} else {
|
||||
if let Some(remaining) =
|
||||
strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path())
|
||||
{
|
||||
path_ends_with(&self.worktree_abs_path, remaining)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn equals_absolute_path(&self, path: &Path) -> bool {
|
||||
if let Some(remaining) =
|
||||
strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path())
|
||||
{
|
||||
self.worktree_abs_path.as_ref() == remaining
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
539
crates/edit_prediction_context/src/declaration_scoring.rs
Normal file
539
crates/edit_prediction_context/src/declaration_scoring.rs
Normal file
@@ -0,0 +1,539 @@
|
||||
use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
|
||||
use collections::HashMap;
|
||||
use language::BufferSnapshot;
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::ProjectEntryId;
|
||||
use serde::Serialize;
|
||||
use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
|
||||
use strum::EnumIter;
|
||||
use text::{Point, ToPoint};
|
||||
use util::RangeExt as _;
|
||||
|
||||
use crate::{
|
||||
CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
|
||||
imports::{Import, Imports, Module},
|
||||
reference::{Reference, ReferenceRegion},
|
||||
syntax_index::SyntaxIndexState,
|
||||
text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
|
||||
};
|
||||
|
||||
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct EditPredictionScoreOptions {
|
||||
pub omit_excerpt_overlaps: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ScoredDeclaration {
|
||||
/// identifier used by the local reference
|
||||
pub identifier: Identifier,
|
||||
pub declaration: Declaration,
|
||||
pub components: DeclarationScoreComponents,
|
||||
}
|
||||
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum DeclarationStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Default)]
|
||||
pub struct DeclarationScores {
|
||||
pub signature: f32,
|
||||
pub declaration: f32,
|
||||
pub retrieval: f32,
|
||||
}
|
||||
|
||||
impl ScoredDeclaration {
|
||||
/// Returns the score for this declaration with the specified style.
|
||||
pub fn score(&self, style: DeclarationStyle) -> f32 {
|
||||
// TODO: handle truncation
|
||||
|
||||
// Score related to how likely this is the correct declaration, range 0 to 1
|
||||
let retrieval = self.retrieval_score();
|
||||
|
||||
// Score related to the distance between the reference and cursor, range 0 to 1
|
||||
let distance_score = if self.components.is_referenced_nearby {
|
||||
1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
|
||||
} else {
|
||||
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
|
||||
0.5
|
||||
};
|
||||
|
||||
// For now instead of linear combination, the scores are just multiplied together.
|
||||
let combined_score = 10.0 * retrieval * distance_score;
|
||||
|
||||
match style {
|
||||
DeclarationStyle::Signature => {
|
||||
combined_score * self.components.excerpt_vs_signature_weighted_overlap
|
||||
}
|
||||
DeclarationStyle::Declaration => {
|
||||
2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retrieval_score(&self) -> f32 {
|
||||
let mut score = if self.components.is_same_file {
|
||||
10.0 / self.components.same_file_declaration_count as f32
|
||||
} else if self.components.path_import_match_count > 0 {
|
||||
3.0
|
||||
} else if self.components.wildcard_path_import_match_count > 0 {
|
||||
1.0
|
||||
} else if self.components.normalized_import_similarity > 0.0 {
|
||||
self.components.normalized_import_similarity
|
||||
} else if self.components.normalized_wildcard_import_similarity > 0.0 {
|
||||
0.5 * self.components.normalized_wildcard_import_similarity
|
||||
} else {
|
||||
1.0 / self.components.declaration_count as f32
|
||||
};
|
||||
score *= 1. + self.components.included_by_others as f32 / 2.;
|
||||
score *= 1. + self.components.includes_others as f32 / 4.;
|
||||
score
|
||||
}
|
||||
|
||||
pub fn size(&self, style: DeclarationStyle) -> usize {
|
||||
match &self.declaration {
|
||||
Declaration::File { declaration, .. } => match style {
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.text.len(),
|
||||
},
|
||||
Declaration::Buffer { declaration, .. } => match style {
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.item_range.len(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn score_density(&self, style: DeclarationStyle) -> f32 {
|
||||
self.score(style) / self.size(style) as f32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scored_declarations(
|
||||
options: &EditPredictionScoreOptions,
|
||||
index: &SyntaxIndexState,
|
||||
excerpt: &EditPredictionExcerpt,
|
||||
excerpt_occurrences: &Occurrences,
|
||||
adjacent_occurrences: &Occurrences,
|
||||
imports: &Imports,
|
||||
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
|
||||
cursor_offset: usize,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Vec<ScoredDeclaration> {
|
||||
let cursor_point = cursor_offset.to_point(¤t_buffer);
|
||||
|
||||
let mut wildcard_import_occurrences = Vec::new();
|
||||
let mut wildcard_import_paths = Vec::new();
|
||||
for wildcard_import in imports.wildcard_modules.iter() {
|
||||
match wildcard_import {
|
||||
Module::Namespace(namespace) => {
|
||||
wildcard_import_occurrences.push(namespace.occurrences())
|
||||
}
|
||||
Module::SourceExact(path) => wildcard_import_paths.push(path),
|
||||
Module::SourceFuzzy(path) => {
|
||||
wildcard_import_occurrences.push(Occurrences::from_path(&path))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut scored_declarations = Vec::new();
|
||||
let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
|
||||
HashMap::default();
|
||||
for (identifier, references) in identifier_to_references {
|
||||
let mut import_occurrences = Vec::new();
|
||||
let mut import_paths = Vec::new();
|
||||
let mut found_external_identifier: Option<&Identifier> = None;
|
||||
|
||||
if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
|
||||
// only use alias when it's the only import, could be generalized if some language
|
||||
// has overlapping aliases
|
||||
//
|
||||
// TODO: when an aliased declaration is included in the prompt, should include the
|
||||
// aliasing in the prompt.
|
||||
//
|
||||
// TODO: For SourceFuzzy consider having componentwise comparison that pays
|
||||
// attention to ordering.
|
||||
if let [
|
||||
Import::Alias {
|
||||
module,
|
||||
external_identifier,
|
||||
},
|
||||
] = imports.as_slice()
|
||||
{
|
||||
match module {
|
||||
Module::Namespace(namespace) => {
|
||||
import_occurrences.push(namespace.occurrences())
|
||||
}
|
||||
Module::SourceExact(path) => import_paths.push(path),
|
||||
Module::SourceFuzzy(path) => {
|
||||
import_occurrences.push(Occurrences::from_path(&path))
|
||||
}
|
||||
}
|
||||
found_external_identifier = Some(&external_identifier);
|
||||
} else {
|
||||
for import in imports {
|
||||
match import {
|
||||
Import::Direct { module } => match module {
|
||||
Module::Namespace(namespace) => {
|
||||
import_occurrences.push(namespace.occurrences())
|
||||
}
|
||||
Module::SourceExact(path) => import_paths.push(path),
|
||||
Module::SourceFuzzy(path) => {
|
||||
import_occurrences.push(Occurrences::from_path(&path))
|
||||
}
|
||||
},
|
||||
Import::Alias { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
|
||||
// TODO: update this to be able to return more declarations? Especially if there is the
|
||||
// ability to quickly filter a large list (based on imports)
|
||||
let identifier_declarations = index
|
||||
.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
|
||||
let declaration_count = identifier_declarations.len();
|
||||
|
||||
if declaration_count == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: option to filter out other candidates when same file / import match
|
||||
let mut checked_declarations = Vec::with_capacity(declaration_count);
|
||||
for (declaration_id, declaration) in identifier_declarations {
|
||||
match declaration {
|
||||
Declaration::Buffer {
|
||||
buffer_id,
|
||||
declaration: buffer_declaration,
|
||||
..
|
||||
} => {
|
||||
if buffer_id == ¤t_buffer.remote_id() {
|
||||
let already_included_in_prompt =
|
||||
range_intersection(&buffer_declaration.item_range, &excerpt.range)
|
||||
.is_some()
|
||||
|| excerpt
|
||||
.parent_declarations
|
||||
.iter()
|
||||
.any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
|
||||
if !options.omit_excerpt_overlaps || !already_included_in_prompt {
|
||||
let declaration_line = buffer_declaration
|
||||
.item_range
|
||||
.start
|
||||
.to_point(current_buffer)
|
||||
.row;
|
||||
let declaration_line_distance =
|
||||
(cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
|
||||
checked_declarations.push(CheckedDeclaration {
|
||||
declaration,
|
||||
same_file_line_distance: Some(declaration_line_distance),
|
||||
path_import_match_count: 0,
|
||||
wildcard_path_import_match_count: 0,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
}
|
||||
}
|
||||
Declaration::File { .. } => {}
|
||||
}
|
||||
let declaration_path = declaration.cached_path();
|
||||
let path_import_match_count = import_paths
|
||||
.iter()
|
||||
.filter(|import_path| {
|
||||
declaration_path_matches_import(&declaration_path, import_path)
|
||||
})
|
||||
.count();
|
||||
let wildcard_path_import_match_count = wildcard_import_paths
|
||||
.iter()
|
||||
.filter(|import_path| {
|
||||
declaration_path_matches_import(&declaration_path, import_path)
|
||||
})
|
||||
.count();
|
||||
checked_declarations.push(CheckedDeclaration {
|
||||
declaration,
|
||||
same_file_line_distance: None,
|
||||
path_import_match_count,
|
||||
wildcard_path_import_match_count,
|
||||
});
|
||||
}
|
||||
|
||||
let mut max_import_similarity = 0.0;
|
||||
let mut max_wildcard_import_similarity = 0.0;
|
||||
|
||||
let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
|
||||
for checked_declaration in checked_declarations {
|
||||
let same_file_declaration_count =
|
||||
index.file_declaration_count(checked_declaration.declaration);
|
||||
|
||||
let declaration = score_declaration(
|
||||
&identifier,
|
||||
&references,
|
||||
checked_declaration,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
&excerpt_occurrences,
|
||||
&adjacent_occurrences,
|
||||
&import_occurrences,
|
||||
&wildcard_import_occurrences,
|
||||
cursor_point,
|
||||
current_buffer,
|
||||
);
|
||||
|
||||
if declaration.components.import_similarity > max_import_similarity {
|
||||
max_import_similarity = declaration.components.import_similarity;
|
||||
}
|
||||
|
||||
if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
|
||||
max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
|
||||
}
|
||||
|
||||
project_entry_id_to_outline_ranges
|
||||
.entry(declaration.declaration.project_entry_id())
|
||||
.or_default()
|
||||
.push(declaration.declaration.item_range());
|
||||
scored_declarations_for_identifier.push(declaration);
|
||||
}
|
||||
|
||||
if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
|
||||
for declaration in scored_declarations_for_identifier.iter_mut() {
|
||||
if max_import_similarity > 0.0 {
|
||||
declaration.components.max_import_similarity = max_import_similarity;
|
||||
declaration.components.normalized_import_similarity =
|
||||
declaration.components.import_similarity / max_import_similarity;
|
||||
}
|
||||
if max_wildcard_import_similarity > 0.0 {
|
||||
declaration.components.normalized_wildcard_import_similarity =
|
||||
declaration.components.wildcard_import_similarity
|
||||
/ max_wildcard_import_similarity;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scored_declarations.extend(scored_declarations_for_identifier);
|
||||
}
|
||||
|
||||
// TODO: Inform this via import / retrieval scores of outline items
|
||||
// TODO: Consider using a sweepline
|
||||
for scored_declaration in scored_declarations.iter_mut() {
|
||||
let project_entry_id = scored_declaration.declaration.project_entry_id();
|
||||
let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
|
||||
continue;
|
||||
};
|
||||
for range in ranges {
|
||||
if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
|
||||
scored_declaration.components.included_by_others += 1
|
||||
} else if scored_declaration
|
||||
.declaration
|
||||
.item_range()
|
||||
.contains_inclusive(range)
|
||||
{
|
||||
scored_declaration.components.includes_others += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scored_declarations.sort_unstable_by_key(|declaration| {
|
||||
Reverse(OrderedFloat(
|
||||
declaration.score(DeclarationStyle::Declaration),
|
||||
))
|
||||
});
|
||||
|
||||
scored_declarations
|
||||
}
|
||||
|
||||
struct CheckedDeclaration<'a> {
|
||||
declaration: &'a Declaration,
|
||||
same_file_line_distance: Option<u32>,
|
||||
path_import_match_count: usize,
|
||||
wildcard_path_import_match_count: usize,
|
||||
}
|
||||
|
||||
fn declaration_path_matches_import(
|
||||
declaration_path: &CachedDeclarationPath,
|
||||
import_path: &Arc<Path>,
|
||||
) -> bool {
|
||||
if import_path.is_absolute() {
|
||||
declaration_path.equals_absolute_path(import_path)
|
||||
} else {
|
||||
declaration_path.ends_with_posix_path(import_path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the overlap of two half-open ranges, or `None` when they do not
/// overlap. Ranges that merely touch at an endpoint produce an empty
/// intersection and therefore yield `None`.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    // Later start, earlier end — borrow first, clone only if non-empty.
    let start = if a.start >= b.start { &a.start } else { &b.start };
    let end = if a.end <= b.end { &a.end } else { &b.end };
    (start < end).then(|| start.clone()..end.clone())
}
|
||||
|
||||
/// Computes the full set of ranking signals for one candidate declaration of
/// `identifier`, relative to the cursor position, the surrounding excerpt,
/// and the buffer's imports. Normalized/max import-similarity fields and the
/// overlap counters are left at zero here; the caller fills them in after all
/// declarations for the identifier have been scored.
fn score_declaration(
    identifier: &Identifier,
    references: &[Reference],
    checked_declaration: CheckedDeclaration,
    same_file_declaration_count: usize,
    declaration_count: usize,
    excerpt_occurrences: &Occurrences,
    adjacent_occurrences: &Occurrences,
    import_occurrences: &[Occurrences],
    wildcard_import_occurrences: &[Occurrences],
    cursor: Point,
    current_buffer: &BufferSnapshot,
) -> ScoredDeclaration {
    let CheckedDeclaration {
        declaration,
        same_file_line_distance,
        path_import_match_count,
        wildcard_path_import_match_count,
    } = checked_declaration;

    // Where in the excerpt the identifier was seen.
    let is_referenced_nearby = references
        .iter()
        .any(|r| r.region == ReferenceRegion::Nearby);
    let is_referenced_in_breadcrumb = references
        .iter()
        .any(|r| r.region == ReferenceRegion::Breadcrumb);
    let reference_count = references.len();
    // Distance (in rows) from the cursor to the closest reference.
    // NOTE(review): `unwrap` assumes `references` is non-empty — presumably
    // callers only score identifiers that were actually referenced; confirm.
    let reference_line_distance = references
        .iter()
        .map(|r| {
            let reference_line = r.range.start.to_point(current_buffer).row as i32;
            (cursor.row as i32 - reference_line).unsigned_abs()
        })
        .min()
        .unwrap();

    let is_same_file = same_file_line_distance.is_some();
    // u32::MAX acts as "infinitely far" for declarations in other files.
    let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);

    // Text-similarity signals: the declaration's full body and its signature,
    // each compared against the excerpt and the lines adjacent to the cursor.
    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
    let excerpt_vs_signature_jaccard =
        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
    let adjacent_vs_item_jaccard =
        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
    let adjacent_vs_signature_jaccard =
        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);

    let excerpt_vs_item_weighted_overlap =
        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
    let excerpt_vs_signature_weighted_overlap =
        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
    let adjacent_vs_item_weighted_overlap =
        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
    let adjacent_vs_signature_weighted_overlap =
        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);

    // Import similarity: how closely the declaration's path resembles any of
    // the buffer's imported namespaces. Only computed when there are imports.
    let mut import_similarity = 0f32;
    let mut wildcard_import_similarity = 0f32;
    if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
        let cached_path = declaration.cached_path();
        let path_occurrences = Occurrences::from_worktree_path(
            cached_path
                .worktree_abs_path
                .file_name()
                .map(|f| f.to_string_lossy()),
            &cached_path.rel_path,
        );
        // Best similarity across all imports (OrderedFloat makes f32 Ord-able).
        import_similarity = import_occurrences
            .iter()
            .map(|namespace_occurrences| {
                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
            })
            .max()
            .map(|similarity| similarity.into_inner())
            .unwrap_or_default();

        // TODO: Consider something other than max
        wildcard_import_similarity = wildcard_import_occurrences
            .iter()
            .map(|namespace_occurrences| {
                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
            })
            .max()
            .map(|similarity| similarity.into_inner())
            .unwrap_or_default();
    }

    // TODO: Consider adding declaration_file_count
    let score_components = DeclarationScoreComponents {
        is_same_file,
        is_referenced_nearby,
        is_referenced_in_breadcrumb,
        reference_line_distance,
        declaration_line_distance,
        reference_count,
        same_file_declaration_count,
        declaration_count,
        excerpt_vs_item_jaccard,
        excerpt_vs_signature_jaccard,
        adjacent_vs_item_jaccard,
        adjacent_vs_signature_jaccard,
        excerpt_vs_item_weighted_overlap,
        excerpt_vs_signature_weighted_overlap,
        adjacent_vs_item_weighted_overlap,
        adjacent_vs_signature_weighted_overlap,
        path_import_match_count,
        wildcard_path_import_match_count,
        import_similarity,
        // Filled in by the caller once the per-identifier maximum is known.
        max_import_similarity: 0.0,
        normalized_import_similarity: 0.0,
        wildcard_import_similarity,
        normalized_wildcard_import_similarity: 0.0,
        // Filled in by the caller's overlap sweep over outline ranges.
        included_by_others: 0,
        includes_others: 0,
    };

    ScoredDeclaration {
        identifier: identifier.clone(),
        declaration: declaration.clone(),
        components: score_components,
    }
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");

        // Trailing suffixes of the declaration path, and the exact absolute
        // path, should all be considered matches.
        let matching: [&Path; 4] = [
            Path::new("maths.ts"),
            Path::new("project/src/maths.ts"),
            Path::new("user/project/src/maths.ts"),
            Path::new("/home/user/project/src/maths.ts"),
        ];
        for import in matching {
            assert!(declaration_path_matches_import(
                &declaration_path,
                &import.into()
            ));
        }

        // A different file name must not match, relative or absolute.
        let non_matching: [&Path; 2] = [
            Path::new("other.ts"),
            Path::new("/home/user/project/src/other.ts"),
        ];
        for import in non_matching {
            assert!(!declaration_path_matches_import(
                &declaration_path,
                &import.into()
            ));
        }
    }
}
|
||||
@@ -1,490 +1,335 @@
|
||||
use crate::assemble_excerpts::assemble_excerpts;
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
|
||||
use project::{LocationLink, Project, ProjectPath};
|
||||
use serde::{Serialize, Serializer};
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
ops::Range,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::{RangeExt as _, ResultExt};
|
||||
|
||||
mod assemble_excerpts;
|
||||
#[cfg(test)]
|
||||
mod edit_prediction_context_tests;
|
||||
mod declaration;
|
||||
mod declaration_scoring;
|
||||
mod excerpt;
|
||||
#[cfg(test)]
|
||||
mod fake_definition_lsp;
|
||||
mod imports;
|
||||
mod outline;
|
||||
mod reference;
|
||||
mod syntax_index;
|
||||
pub mod text_similarity;
|
||||
|
||||
pub use cloud_llm_client::predict_edits_v3::Line;
|
||||
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
const IDENTIFIER_LINE_COUNT: u32 = 3;
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use language::BufferSnapshot;
|
||||
use text::{Point, ToOffset as _};
|
||||
|
||||
pub struct RelatedExcerptStore {
|
||||
project: WeakEntity<Project>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
cache: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
|
||||
identifier_line_count: u32,
|
||||
}
|
||||
pub use declaration::*;
|
||||
pub use declaration_scoring::*;
|
||||
pub use excerpt::*;
|
||||
pub use imports::*;
|
||||
pub use reference::*;
|
||||
pub use syntax_index::*;
|
||||
|
||||
/// Events emitted by `RelatedExcerptStore` as it refreshes.
pub enum RelatedExcerptStoreEvent {
    /// A refresh of the related excerpts has begun.
    StartedRefresh,
    /// A refresh finished; carries cache statistics and latency measurements
    /// for the definition requests that were issued.
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}
|
||||
pub use predict_edits_v3::Line;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
struct Identifier {
|
||||
pub name: String,
|
||||
pub range: Range<Anchor>,
|
||||
}
|
||||
|
||||
enum DefinitionTask {
|
||||
CacheHit(Arc<CacheEntry>),
|
||||
CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CacheEntry {
|
||||
definitions: SmallVec<[CachedDefinition; 1]>,
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct EditPredictionContextOptions {
|
||||
pub use_imports: bool,
|
||||
pub excerpt: EditPredictionExcerptOptions,
|
||||
pub score: EditPredictionScoreOptions,
|
||||
pub max_retrieved_declarations: u8,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct CachedDefinition {
|
||||
path: ProjectPath,
|
||||
buffer: Entity<Buffer>,
|
||||
anchor_range: Range<Anchor>,
|
||||
pub struct EditPredictionContext {
|
||||
pub excerpt: EditPredictionExcerpt,
|
||||
pub excerpt_text: EditPredictionExcerptText,
|
||||
pub cursor_point: Point,
|
||||
pub declarations: Vec<ScoredDeclaration>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct RelatedFile {
|
||||
#[serde(serialize_with = "serialize_project_path")]
|
||||
pub path: ProjectPath,
|
||||
#[serde(skip)]
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub excerpts: Vec<RelatedExcerpt>,
|
||||
pub max_row: u32,
|
||||
}
|
||||
|
||||
impl RelatedFile {
|
||||
pub fn merge_excerpts(&mut self) {
|
||||
self.excerpts.sort_unstable_by(|a, b| {
|
||||
a.point_range
|
||||
.start
|
||||
.cmp(&b.point_range.start)
|
||||
.then(b.point_range.end.cmp(&a.point_range.end))
|
||||
impl EditPredictionContext {
|
||||
pub fn gather_context_in_background(
|
||||
cursor_point: Point,
|
||||
buffer: BufferSnapshot,
|
||||
options: EditPredictionContextOptions,
|
||||
syntax_index: Option<Entity<SyntaxIndex>>,
|
||||
cx: &mut App,
|
||||
) -> Task<Option<Self>> {
|
||||
let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
|
||||
let mut path = f.worktree.read(cx).absolutize(&f.path);
|
||||
if path.pop() { Some(path) } else { None }
|
||||
});
|
||||
|
||||
let mut index = 1;
|
||||
while index < self.excerpts.len() {
|
||||
if self.excerpts[index - 1]
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index].point_range.start)
|
||||
.is_ge()
|
||||
{
|
||||
let removed = self.excerpts.remove(index);
|
||||
if removed
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index - 1].point_range.end)
|
||||
.is_gt()
|
||||
{
|
||||
self.excerpts[index - 1].point_range.end = removed.point_range.end;
|
||||
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
|
||||
}
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct RelatedExcerpt {
|
||||
#[serde(skip)]
|
||||
pub anchor_range: Range<Anchor>,
|
||||
#[serde(serialize_with = "serialize_point_range")]
|
||||
pub point_range: Range<Point>,
|
||||
#[serde(serialize_with = "serialize_rope")]
|
||||
pub text: Rope,
|
||||
}
|
||||
|
||||
fn serialize_project_path<S: Serializer>(
|
||||
project_path: &ProjectPath,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
project_path.path.serialize(serializer)
|
||||
}
|
||||
|
||||
/// Serializes a `Rope` by flattening it to an owned `String` first.
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
    let text = rope.to_string();
    text.serialize(serializer)
}
|
||||
|
||||
fn serialize_point_range<S: Serializer>(
|
||||
range: &Range<Point>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
[
|
||||
[range.start.row, range.start.column],
|
||||
[range.end.row, range.end.column],
|
||||
]
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
|
||||
|
||||
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
|
||||
|
||||
impl RelatedExcerptStore {
|
||||
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
|
||||
let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let executor = cx.background_executor().clone();
|
||||
while let Some((mut buffer, mut position)) = update_rx.next().await {
|
||||
let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
|
||||
loop {
|
||||
futures::select_biased! {
|
||||
next = update_rx.next() => {
|
||||
if let Some((new_buffer, new_position)) = next {
|
||||
buffer = new_buffer;
|
||||
position = new_position;
|
||||
timer = executor.timer(DEBOUNCE_DURATION).fuse();
|
||||
} else {
|
||||
return anyhow::Ok(());
|
||||
}
|
||||
}
|
||||
_ = timer => break,
|
||||
}
|
||||
}
|
||||
|
||||
Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
RelatedExcerptStore {
|
||||
project: project.downgrade(),
|
||||
update_tx,
|
||||
related_files: Vec::new(),
|
||||
cache: Default::default(),
|
||||
identifier_line_count: IDENTIFIER_LINE_COUNT,
|
||||
if let Some(syntax_index) = syntax_index {
|
||||
let index_state =
|
||||
syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
|
||||
cx.background_spawn(async move {
|
||||
let parent_abs_path = parent_abs_path.as_deref();
|
||||
let index_state = index_state.upgrade()?;
|
||||
let index_state = index_state.lock().await;
|
||||
Self::gather_context(
|
||||
cursor_point,
|
||||
&buffer,
|
||||
parent_abs_path,
|
||||
&options,
|
||||
Some(&index_state),
|
||||
)
|
||||
})
|
||||
} else {
|
||||
cx.background_spawn(async move {
|
||||
let parent_abs_path = parent_abs_path.as_deref();
|
||||
Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_identifier_line_count(&mut self, count: u32) {
|
||||
self.identifier_line_count = count;
|
||||
pub fn gather_context(
|
||||
cursor_point: Point,
|
||||
buffer: &BufferSnapshot,
|
||||
parent_abs_path: Option<&Path>,
|
||||
options: &EditPredictionContextOptions,
|
||||
index_state: Option<&SyntaxIndexState>,
|
||||
) -> Option<Self> {
|
||||
let imports = if options.use_imports {
|
||||
Imports::gather(&buffer, parent_abs_path)
|
||||
} else {
|
||||
Imports::default()
|
||||
};
|
||||
Self::gather_context_with_references_fn(
|
||||
cursor_point,
|
||||
buffer,
|
||||
&imports,
|
||||
options,
|
||||
index_state,
|
||||
references_in_excerpt,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
|
||||
self.update_tx.unbounded_send((buffer, position)).ok();
|
||||
}
|
||||
pub fn gather_context_with_references_fn(
|
||||
cursor_point: Point,
|
||||
buffer: &BufferSnapshot,
|
||||
imports: &Imports,
|
||||
options: &EditPredictionContextOptions,
|
||||
index_state: Option<&SyntaxIndexState>,
|
||||
get_references: impl FnOnce(
|
||||
&EditPredictionExcerpt,
|
||||
&EditPredictionExcerptText,
|
||||
&BufferSnapshot,
|
||||
) -> HashMap<Identifier, Vec<Reference>>,
|
||||
) -> Option<Self> {
|
||||
let excerpt = EditPredictionExcerpt::select_from_buffer(
|
||||
cursor_point,
|
||||
buffer,
|
||||
&options.excerpt,
|
||||
index_state,
|
||||
)?;
|
||||
let excerpt_text = excerpt.text(buffer);
|
||||
|
||||
pub fn related_files(&self) -> &[RelatedFile] {
|
||||
&self.related_files
|
||||
}
|
||||
let declarations = if options.max_retrieved_declarations > 0
|
||||
&& let Some(index_state) = index_state
|
||||
{
|
||||
let excerpt_occurrences =
|
||||
text_similarity::Occurrences::within_string(&excerpt_text.body);
|
||||
|
||||
async fn fetch_excerpts(
|
||||
this: WeakEntity<Self>,
|
||||
buffer: Entity<Buffer>,
|
||||
position: Anchor,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
|
||||
(
|
||||
this.project.upgrade(),
|
||||
buffer.read(cx).snapshot(),
|
||||
this.identifier_line_count,
|
||||
)
|
||||
})?;
|
||||
let Some(project) = project else {
|
||||
return Ok(());
|
||||
let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
|
||||
let adjacent_end = Point::new(cursor_point.row + 1, 0);
|
||||
let adjacent_occurrences = text_similarity::Occurrences::within_string(
|
||||
&buffer
|
||||
.text_for_range(adjacent_start..adjacent_end)
|
||||
.collect::<String>(),
|
||||
);
|
||||
|
||||
let cursor_offset_in_file = cursor_point.to_offset(buffer);
|
||||
|
||||
let references = get_references(&excerpt, &excerpt_text, buffer);
|
||||
|
||||
let mut declarations = scored_declarations(
|
||||
&options.score,
|
||||
&index_state,
|
||||
&excerpt,
|
||||
&excerpt_occurrences,
|
||||
&adjacent_occurrences,
|
||||
&imports,
|
||||
references,
|
||||
cursor_offset_in_file,
|
||||
buffer,
|
||||
);
|
||||
// TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
|
||||
declarations.truncate(options.max_retrieved_declarations as usize);
|
||||
declarations
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let file = snapshot.file().cloned();
|
||||
if let Some(file) = &file {
|
||||
log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
|
||||
}
|
||||
|
||||
this.update(cx, |_, cx| {
|
||||
cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
|
||||
})?;
|
||||
|
||||
let identifiers = cx
|
||||
.background_spawn(async move {
|
||||
identifiers_for_position(&snapshot, position, identifier_line_count)
|
||||
})
|
||||
.await;
|
||||
|
||||
let async_cx = cx.clone();
|
||||
let start_time = Instant::now();
|
||||
let futures = this.update(cx, |this, cx| {
|
||||
identifiers
|
||||
.into_iter()
|
||||
.filter_map(|identifier| {
|
||||
let task = if let Some(entry) = this.cache.get(&identifier) {
|
||||
DefinitionTask::CacheHit(entry.clone())
|
||||
} else {
|
||||
DefinitionTask::CacheMiss(
|
||||
this.project
|
||||
.update(cx, |project, cx| {
|
||||
project.definitions(&buffer, identifier.range.start, cx)
|
||||
})
|
||||
.ok()?,
|
||||
)
|
||||
};
|
||||
|
||||
let cx = async_cx.clone();
|
||||
let project = project.clone();
|
||||
Some(async move {
|
||||
match task {
|
||||
DefinitionTask::CacheHit(cache_entry) => {
|
||||
Some((identifier, cache_entry, None))
|
||||
}
|
||||
DefinitionTask::CacheMiss(task) => {
|
||||
let locations = task.await.log_err()??;
|
||||
let duration = start_time.elapsed();
|
||||
cx.update(|cx| {
|
||||
(
|
||||
identifier,
|
||||
Arc::new(CacheEntry {
|
||||
definitions: locations
|
||||
.into_iter()
|
||||
.filter_map(|location| {
|
||||
process_definition(location, &project, cx)
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
Some(duration),
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})?;
|
||||
|
||||
let mut cache_hit_count = 0;
|
||||
let mut cache_miss_count = 0;
|
||||
let mut mean_definition_latency = Duration::ZERO;
|
||||
let mut max_definition_latency = Duration::ZERO;
|
||||
let mut new_cache = HashMap::default();
|
||||
new_cache.reserve(futures.len());
|
||||
for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
|
||||
new_cache.insert(identifier, entry);
|
||||
if let Some(duration) = duration {
|
||||
cache_miss_count += 1;
|
||||
mean_definition_latency += duration;
|
||||
max_definition_latency = max_definition_latency.max(duration);
|
||||
} else {
|
||||
cache_hit_count += 1;
|
||||
}
|
||||
}
|
||||
mean_definition_latency /= cache_miss_count.max(1) as u32;
|
||||
|
||||
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
|
||||
|
||||
if let Some(file) = &file {
|
||||
log::debug!(
|
||||
"finished retrieving context buffer:{}, latency:{:?}",
|
||||
file.path().as_unix_str(),
|
||||
start_time.elapsed()
|
||||
);
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.cache = new_cache;
|
||||
this.related_files = related_files;
|
||||
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
});
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn rebuild_related_files(
|
||||
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
|
||||
let mut snapshots = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
for definition in &entry.definitions {
|
||||
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
|
||||
definition
|
||||
.buffer
|
||||
.read_with(cx, |buffer, _| buffer.parsing_idle())?
|
||||
.await;
|
||||
e.insert(
|
||||
definition
|
||||
.buffer
|
||||
.read_with(cx, |buffer, _| buffer.snapshot())?,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(cx
|
||||
.background_spawn(async move {
|
||||
let mut files = Vec::<RelatedFile>::new();
|
||||
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
|
||||
let mut paths_by_buffer = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
for definition in &entry.definitions {
|
||||
let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
|
||||
continue;
|
||||
};
|
||||
paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
|
||||
ranges_by_buffer
|
||||
.entry(definition.buffer.clone())
|
||||
.or_default()
|
||||
.push(definition.anchor_range.to_point(snapshot));
|
||||
}
|
||||
}
|
||||
|
||||
for (buffer, ranges) in ranges_by_buffer {
|
||||
let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
|
||||
continue;
|
||||
};
|
||||
let excerpts = assemble_excerpts(snapshot, ranges);
|
||||
files.push(RelatedFile {
|
||||
path: project_path.clone(),
|
||||
buffer: buffer.downgrade(),
|
||||
excerpts,
|
||||
max_row: snapshot.max_point().row,
|
||||
});
|
||||
}
|
||||
|
||||
files.sort_by_key(|file| file.path.clone());
|
||||
(new_entries, files)
|
||||
Some(Self {
|
||||
excerpt,
|
||||
excerpt_text,
|
||||
cursor_point,
|
||||
declarations,
|
||||
})
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
fn process_definition(
|
||||
location: LocationLink,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<CachedDefinition> {
|
||||
let buffer = location.target.buffer.read(cx);
|
||||
let anchor_range = location.target.range;
|
||||
let file = buffer.file()?;
|
||||
let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
|
||||
if worktree.read(cx).is_single_file() {
|
||||
return None;
|
||||
}
|
||||
Some(CachedDefinition {
|
||||
path: ProjectPath {
|
||||
worktree_id: file.worktree_id(cx),
|
||||
path: file.path().clone(),
|
||||
},
|
||||
buffer: location.target.buffer,
|
||||
anchor_range,
|
||||
})
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Gets all of the identifiers that are present in the given line, and its containing
|
||||
/// outline items.
|
||||
fn identifiers_for_position(
|
||||
buffer: &BufferSnapshot,
|
||||
position: Anchor,
|
||||
identifier_line_count: u32,
|
||||
) -> Vec<Identifier> {
|
||||
let offset = position.to_offset(buffer);
|
||||
let point = buffer.offset_to_point(offset);
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
// Search for identifiers on lines adjacent to the cursor.
|
||||
let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
|
||||
let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
|
||||
let line_range = start..end;
|
||||
let mut ranges = vec![line_range.to_offset(&buffer)];
|
||||
use crate::{EditPredictionExcerptOptions, SyntaxIndex};
|
||||
|
||||
// Search for identifiers mentioned in headers/signatures of containing outline items.
|
||||
let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
|
||||
for item in outline_items {
|
||||
if let Some(body_range) = item.body_range(&buffer) {
|
||||
ranges.push(item.range.start..body_range.start.to_offset(&buffer));
|
||||
} else {
|
||||
ranges.push(item.range.clone());
|
||||
}
|
||||
#[gpui::test]
|
||||
async fn test_call_site(cx: &mut TestAppContext) {
|
||||
let (project, index, _rust_lang_id) = init_test(cx).await;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project.find_project_path("c.rs", cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// first process_data call site
|
||||
let cursor_point = language::Point::new(8, 21);
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let context = cx
|
||||
.update(|cx| {
|
||||
EditPredictionContext::gather_context_in_background(
|
||||
cursor_point,
|
||||
buffer_snapshot,
|
||||
EditPredictionContextOptions {
|
||||
use_imports: true,
|
||||
excerpt: EditPredictionExcerptOptions {
|
||||
max_bytes: 60,
|
||||
min_bytes: 10,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
score: EditPredictionScoreOptions {
|
||||
omit_excerpt_overlaps: true,
|
||||
},
|
||||
max_retrieved_declarations: u8::MAX,
|
||||
},
|
||||
Some(index.clone()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut snippet_identifiers = context
|
||||
.declarations
|
||||
.iter()
|
||||
.map(|snippet| snippet.identifier.name.as_ref())
|
||||
.collect::<Vec<_>>();
|
||||
snippet_identifiers.sort();
|
||||
assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
|
||||
drop(buffer);
|
||||
}
|
||||
|
||||
ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
|
||||
ranges.dedup_by(|a, b| {
|
||||
if a.start <= b.end {
|
||||
b.start = b.start.min(a.start);
|
||||
b.end = b.end.max(a.end);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
let mut identifiers = Vec::new();
|
||||
let outer_range =
|
||||
ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
|
||||
|
||||
let mut captures = buffer
|
||||
.syntax
|
||||
.captures(outer_range.clone(), &buffer.text, |grammar| {
|
||||
grammar
|
||||
.highlights_config
|
||||
.as_ref()
|
||||
.map(|config| &config.query)
|
||||
async fn init_test(
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
});
|
||||
|
||||
for range in ranges {
|
||||
captures.set_byte_range(range.start..outer_range.end);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"a.rs": indoc! {r#"
|
||||
fn main() {
|
||||
let x = 1;
|
||||
let y = 2;
|
||||
let z = add(x, y);
|
||||
println!("Result: {}", z);
|
||||
}
|
||||
|
||||
let mut last_range = None;
|
||||
while let Some(capture) = captures.peek() {
|
||||
let node_range = capture.node.byte_range();
|
||||
if node_range.start > range.end {
|
||||
break;
|
||||
}
|
||||
let config = captures.grammars()[capture.grammar_index]
|
||||
.highlights_config
|
||||
.as_ref();
|
||||
fn add(a: i32, b: i32) -> i32 {
|
||||
a + b
|
||||
}
|
||||
"#},
|
||||
"b.rs": indoc! {"
|
||||
pub struct Config {
|
||||
pub name: String,
|
||||
pub value: i32,
|
||||
}
|
||||
|
||||
if let Some(config) = config
|
||||
&& config.identifier_capture_indices.contains(&capture.index)
|
||||
&& range.contains_inclusive(&node_range)
|
||||
&& Some(&node_range) != last_range.as_ref()
|
||||
{
|
||||
let name = buffer.text_for_range(node_range.clone()).collect();
|
||||
identifiers.push(Identifier {
|
||||
range: buffer.anchor_after(node_range.start)
|
||||
..buffer.anchor_before(node_range.end),
|
||||
name,
|
||||
});
|
||||
last_range = Some(node_range);
|
||||
}
|
||||
impl Config {
|
||||
pub fn new(name: String, value: i32) -> Self {
|
||||
Config { name, value }
|
||||
}
|
||||
}
|
||||
"},
|
||||
"c.rs": indoc! {r#"
|
||||
use std::collections::HashMap;
|
||||
|
||||
captures.advance();
|
||||
}
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let data: Vec<i32> = args[1..]
|
||||
.iter()
|
||||
.filter_map(|s| s.parse().ok())
|
||||
.collect();
|
||||
let result = process_data(data);
|
||||
println!("{:?}", result);
|
||||
}
|
||||
|
||||
fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
|
||||
let mut counts = HashMap::new();
|
||||
for value in data {
|
||||
*counts.entry(value).or_insert(0) += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_process_data() {
|
||||
let data = vec![1, 2, 2, 3];
|
||||
let result = process_data(data);
|
||||
assert_eq!(result.get(&2), Some(&2));
|
||||
}
|
||||
}
|
||||
"#}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
let lang = rust_lang();
|
||||
let lang_id = lang.id();
|
||||
language_registry.add(Arc::new(lang));
|
||||
|
||||
let file_indexing_parallelism = 2;
|
||||
let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
(project, index, lang_id)
|
||||
}
|
||||
|
||||
identifiers
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
|
||||
.unwrap()
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,510 +0,0 @@
|
||||
use super::*;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::{Point, ToPoint as _, rust_lang};
|
||||
use lsp::FakeLanguageServer;
|
||||
use project::{FakeFs, LocationLink, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::fmt::Write as _;
|
||||
use util::{path, test::marked_text_ranges};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/root"), test_project_1()).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let mut servers = setup_fake_lsp(&project, cx);
|
||||
|
||||
let (buffer, _handle) = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _server = servers.next().await.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
|
||||
related_excerpt_store.update(cx, |store, cx| {
|
||||
let position = {
|
||||
let buffer = buffer.read(cx);
|
||||
let offset = buffer.text().find("todo").unwrap();
|
||||
buffer.anchor_before(offset)
|
||||
};
|
||||
|
||||
store.set_identifier_line_count(0);
|
||||
store.refresh(buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
cx.executor().advance_clock(DEBOUNCE_DURATION);
|
||||
related_excerpt_store.update(cx, |store, _| {
|
||||
let excerpts = store.related_files();
|
||||
assert_related_files(
|
||||
&excerpts,
|
||||
&[
|
||||
(
|
||||
"src/company.rs",
|
||||
&[indoc! {"
|
||||
pub struct Company {
|
||||
owner: Arc<Person>,
|
||||
address: Address,
|
||||
}"}],
|
||||
),
|
||||
(
|
||||
"src/main.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
pub struct Session {
|
||||
company: Arc<Company>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn set_company(&mut self, company: Arc<Company>) {"},
|
||||
indoc! {"
|
||||
}
|
||||
}"},
|
||||
],
|
||||
),
|
||||
(
|
||||
"src/person.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
impl Person {
|
||||
pub fn get_first_name(&self) -> &str {
|
||||
&self.first_name
|
||||
}"},
|
||||
"}",
|
||||
],
|
||||
),
|
||||
],
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_assemble_excerpts(cx: &mut TestAppContext) {
|
||||
let table = [
|
||||
(
|
||||
indoc! {r#"
|
||||
struct User {
|
||||
first_name: String,
|
||||
«last_name»: String,
|
||||
age: u32,
|
||||
email: String,
|
||||
create_at: Instant,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn first_name(&self) -> String {
|
||||
self.first_name.clone()
|
||||
}
|
||||
|
||||
pub fn full_name(&self) -> String {
|
||||
« format!("{} {}", self.first_name, self.last_name)
|
||||
» }
|
||||
}
|
||||
"#},
|
||||
indoc! {r#"
|
||||
struct User {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
…
|
||||
}
|
||||
|
||||
impl User {
|
||||
…
|
||||
pub fn full_name(&self) -> String {
|
||||
format!("{} {}", self.first_name, self.last_name)
|
||||
}
|
||||
}
|
||||
"#},
|
||||
),
|
||||
(
|
||||
indoc! {r#"
|
||||
struct «User» {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
age: u32,
|
||||
}
|
||||
|
||||
impl User {
|
||||
// methods
|
||||
}
|
||||
"#},
|
||||
indoc! {r#"
|
||||
struct User {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
age: u32,
|
||||
}
|
||||
…
|
||||
"#},
|
||||
),
|
||||
(
|
||||
indoc! {r#"
|
||||
trait «FooProvider» {
|
||||
const NAME: &'static str;
|
||||
|
||||
fn provide_foo(&self, id: usize) -> Foo;
|
||||
|
||||
fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
ids.iter()
|
||||
.map(|id| self.provide_foo(*id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#
|
||||
},
|
||||
indoc! {r#"
|
||||
trait FooProvider {
|
||||
const NAME: &'static str;
|
||||
|
||||
fn provide_foo(&self, id: usize) -> Foo;
|
||||
|
||||
fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
…
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#},
|
||||
),
|
||||
(
|
||||
indoc! {r#"
|
||||
trait «Something» {
|
||||
fn method1(&self, id: usize) -> Foo;
|
||||
|
||||
fn method2(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
struct Helper1 {
|
||||
field1: usize,
|
||||
}
|
||||
|
||||
struct Helper2 {
|
||||
field2: usize,
|
||||
}
|
||||
|
||||
struct Helper3 {
|
||||
filed2: usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#
|
||||
},
|
||||
indoc! {r#"
|
||||
trait Something {
|
||||
fn method1(&self, id: usize) -> Foo;
|
||||
|
||||
fn method2(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
…
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#},
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected_output) in table {
|
||||
let (input, ranges) = marked_text_ranges(&input, false);
|
||||
let buffer = cx.new(|cx| Buffer::local(input, cx).with_language(rust_lang(), cx));
|
||||
buffer.read_with(cx, |buffer, _cx| {
|
||||
let ranges: Vec<Range<Point>> = ranges
|
||||
.into_iter()
|
||||
.map(|range| range.to_point(&buffer))
|
||||
.collect();
|
||||
|
||||
let excerpts = assemble_excerpts(&buffer.snapshot(), ranges);
|
||||
|
||||
let output = format_excerpts(buffer, &excerpts);
|
||||
assert_eq!(output, expected_output);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fake_definition_lsp(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/root"), test_project_1()).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let mut servers = setup_fake_lsp(&project, cx);
|
||||
|
||||
let (buffer, _handle) = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _server = servers.next().await.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
||||
|
||||
let definitions = project
|
||||
.update(cx, |project, cx| {
|
||||
let offset = buffer_text.find("Address {").unwrap();
|
||||
project.definitions(&buffer, offset, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_definitions(&definitions, &["pub struct Address {"], cx);
|
||||
|
||||
let definitions = project
|
||||
.update(cx, |project, cx| {
|
||||
let offset = buffer_text.find("State::CA").unwrap();
|
||||
project.definitions(&buffer, offset, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_definitions(&definitions, &["pub enum State {"], cx);
|
||||
|
||||
let definitions = project
|
||||
.update(cx, |project, cx| {
|
||||
let offset = buffer_text.find("to_string()").unwrap();
|
||||
project.definitions(&buffer, offset, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(|cx| SettingsStore::test(cx));
|
||||
cx.set_global(settings_store);
|
||||
env_logger::try_init().ok();
|
||||
}
|
||||
|
||||
fn setup_fake_lsp(
|
||||
project: &Entity<Project>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> UnboundedReceiver<FakeLanguageServer> {
|
||||
let (language_registry, fs) = project.read_with(cx, |project, _| {
|
||||
(project.languages().clone(), project.fs().clone())
|
||||
});
|
||||
let language = rust_lang();
|
||||
language_registry.add(language.clone());
|
||||
fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs)
|
||||
}
|
||||
|
||||
fn test_project_1() -> serde_json::Value {
|
||||
let person_rs = indoc! {r#"
|
||||
pub struct Person {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
email: String,
|
||||
age: u32,
|
||||
}
|
||||
|
||||
impl Person {
|
||||
pub fn get_first_name(&self) -> &str {
|
||||
&self.first_name
|
||||
}
|
||||
|
||||
pub fn get_last_name(&self) -> &str {
|
||||
&self.last_name
|
||||
}
|
||||
|
||||
pub fn get_email(&self) -> &str {
|
||||
&self.email
|
||||
}
|
||||
|
||||
pub fn get_age(&self) -> u32 {
|
||||
self.age
|
||||
}
|
||||
}
|
||||
"#};
|
||||
|
||||
let address_rs = indoc! {r#"
|
||||
pub struct Address {
|
||||
street: String,
|
||||
city: String,
|
||||
state: State,
|
||||
zip: u32,
|
||||
}
|
||||
|
||||
pub enum State {
|
||||
CA,
|
||||
OR,
|
||||
WA,
|
||||
TX,
|
||||
// ...
|
||||
}
|
||||
|
||||
impl Address {
|
||||
pub fn get_street(&self) -> &str {
|
||||
&self.street
|
||||
}
|
||||
|
||||
pub fn get_city(&self) -> &str {
|
||||
&self.city
|
||||
}
|
||||
|
||||
pub fn get_state(&self) -> State {
|
||||
self.state
|
||||
}
|
||||
|
||||
pub fn get_zip(&self) -> u32 {
|
||||
self.zip
|
||||
}
|
||||
}
|
||||
"#};
|
||||
|
||||
let company_rs = indoc! {r#"
|
||||
use super::person::Person;
|
||||
use super::address::Address;
|
||||
|
||||
pub struct Company {
|
||||
owner: Arc<Person>,
|
||||
address: Address,
|
||||
}
|
||||
|
||||
impl Company {
|
||||
pub fn get_owner(&self) -> &Person {
|
||||
&self.owner
|
||||
}
|
||||
|
||||
pub fn get_address(&self) -> &Address {
|
||||
&self.address
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
format!("{} ({})", self.owner.first_name, self.address.city)
|
||||
}
|
||||
}
|
||||
"#};
|
||||
|
||||
let main_rs = indoc! {r#"
|
||||
use std::sync::Arc;
|
||||
use super::person::Person;
|
||||
use super::address::Address;
|
||||
use super::company::Company;
|
||||
|
||||
pub struct Session {
|
||||
company: Arc<Company>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn set_company(&mut self, company: Arc<Company>) {
|
||||
self.company = company;
|
||||
if company.owner != self.company.owner {
|
||||
log("new owner", company.owner.get_first_name()); todo();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let company = Company {
|
||||
owner: Arc::new(Person {
|
||||
first_name: "John".to_string(),
|
||||
last_name: "Doe".to_string(),
|
||||
email: "john@example.com".to_string(),
|
||||
age: 30,
|
||||
}),
|
||||
address: Address {
|
||||
street: "123 Main St".to_string(),
|
||||
city: "Anytown".to_string(),
|
||||
state: State::CA,
|
||||
zip: 12345,
|
||||
},
|
||||
};
|
||||
|
||||
println!("Company: {}", company.to_string());
|
||||
}
|
||||
"#};
|
||||
|
||||
json!({
|
||||
"src": {
|
||||
"person.rs": person_rs,
|
||||
"address.rs": address_rs,
|
||||
"company.rs": company_rs,
|
||||
"main.rs": main_rs,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) {
|
||||
let actual_files = actual_files
|
||||
.iter()
|
||||
.map(|file| {
|
||||
let excerpts = file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.text.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
(file.path.path.as_unix_str(), excerpts)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let expected_excerpts = expected_files
|
||||
.iter()
|
||||
.map(|(path, texts)| {
|
||||
(
|
||||
*path,
|
||||
texts
|
||||
.iter()
|
||||
.map(|line| line.to_string())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
pretty_assertions::assert_eq!(actual_files, expected_excerpts)
|
||||
}
|
||||
|
||||
fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) {
|
||||
let actual_first_lines = definitions
|
||||
.iter()
|
||||
.map(|definition| {
|
||||
definition.target.buffer.read_with(cx, |buffer, _| {
|
||||
let mut start = definition.target.range.start.to_point(&buffer);
|
||||
start.column = 0;
|
||||
let end = Point::new(start.row, buffer.line_len(start.row));
|
||||
buffer
|
||||
.text_for_range(start..end)
|
||||
.collect::<String>()
|
||||
.trim()
|
||||
.to_string()
|
||||
})
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
assert_eq!(actual_first_lines, first_lines);
|
||||
}
|
||||
|
||||
fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
|
||||
let mut output = String::new();
|
||||
let file_line_count = buffer.max_point().row;
|
||||
let mut current_row = 0;
|
||||
for excerpt in excerpts {
|
||||
if excerpt.text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if current_row < excerpt.point_range.start.row {
|
||||
writeln!(&mut output, "…").unwrap();
|
||||
}
|
||||
current_row = excerpt.point_range.start.row;
|
||||
|
||||
for line in excerpt.text.to_string().lines() {
|
||||
output.push_str(line);
|
||||
output.push('\n');
|
||||
current_row += 1;
|
||||
}
|
||||
}
|
||||
if current_row < file_line_count {
|
||||
writeln!(&mut output, "…").unwrap();
|
||||
}
|
||||
output
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
use cloud_llm_client::predict_edits_v3::Line;
|
||||
use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _};
|
||||
use language::{BufferSnapshot, LanguageId};
|
||||
use std::ops::Range;
|
||||
use text::{Point, ToOffset as _, ToPoint as _};
|
||||
use tree_sitter::{Node, TreeCursor};
|
||||
use util::RangeExt;
|
||||
|
||||
use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// - Test parent signatures
|
||||
@@ -29,16 +31,19 @@ pub struct EditPredictionExcerptOptions {
|
||||
pub target_before_cursor_over_total_bytes: f32,
|
||||
}
|
||||
|
||||
// TODO: consider merging these
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EditPredictionExcerpt {
|
||||
pub range: Range<usize>,
|
||||
pub line_range: Range<Line>,
|
||||
pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EditPredictionExcerptText {
|
||||
pub body: String,
|
||||
pub parent_signatures: Vec<String>,
|
||||
pub language_id: Option<LanguageId>,
|
||||
}
|
||||
|
||||
@@ -47,8 +52,17 @@ impl EditPredictionExcerpt {
|
||||
let body = buffer
|
||||
.text_for_range(self.range.clone())
|
||||
.collect::<String>();
|
||||
let parent_signatures = self
|
||||
.parent_declarations
|
||||
.iter()
|
||||
.map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
|
||||
.collect();
|
||||
let language_id = buffer.language().map(|l| l.id());
|
||||
EditPredictionExcerptText { body, language_id }
|
||||
EditPredictionExcerptText {
|
||||
body,
|
||||
parent_signatures,
|
||||
language_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
|
||||
@@ -65,6 +79,7 @@ impl EditPredictionExcerpt {
|
||||
query_point: Point,
|
||||
buffer: &BufferSnapshot,
|
||||
options: &EditPredictionExcerptOptions,
|
||||
syntax_index: Option<&SyntaxIndexState>,
|
||||
) -> Option<Self> {
|
||||
if buffer.len() <= options.max_bytes {
|
||||
log::debug!(
|
||||
@@ -74,7 +89,11 @@ impl EditPredictionExcerpt {
|
||||
);
|
||||
let offset_range = 0..buffer.len();
|
||||
let line_range = Line(0)..Line(buffer.max_point().row);
|
||||
return Some(EditPredictionExcerpt::new(offset_range, line_range));
|
||||
return Some(EditPredictionExcerpt::new(
|
||||
offset_range,
|
||||
line_range,
|
||||
Vec::new(),
|
||||
));
|
||||
}
|
||||
|
||||
let query_offset = query_point.to_offset(buffer);
|
||||
@@ -85,10 +104,19 @@ impl EditPredictionExcerpt {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parent_declarations = if let Some(syntax_index) = syntax_index {
|
||||
syntax_index
|
||||
.buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let excerpt_selector = ExcerptSelector {
|
||||
query_offset,
|
||||
query_range,
|
||||
query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
|
||||
parent_declarations: &parent_declarations,
|
||||
buffer,
|
||||
options,
|
||||
};
|
||||
@@ -111,10 +139,20 @@ impl EditPredictionExcerpt {
|
||||
excerpt_selector.select_lines()
|
||||
}
|
||||
|
||||
fn new(range: Range<usize>, line_range: Range<Line>) -> Self {
|
||||
fn new(
|
||||
range: Range<usize>,
|
||||
line_range: Range<Line>,
|
||||
parent_declarations: Vec<(DeclarationId, Range<usize>)>,
|
||||
) -> Self {
|
||||
let size = range.len()
|
||||
+ parent_declarations
|
||||
.iter()
|
||||
.map(|(_, range)| range.len())
|
||||
.sum::<usize>();
|
||||
Self {
|
||||
size: range.len(),
|
||||
range,
|
||||
parent_declarations,
|
||||
size,
|
||||
line_range,
|
||||
}
|
||||
}
|
||||
@@ -124,7 +162,14 @@ impl EditPredictionExcerpt {
|
||||
// this is an issue because parent_signature_ranges may be incorrect
|
||||
log::error!("bug: with_expanded_range called with disjoint range");
|
||||
}
|
||||
Self::new(new_range, new_line_range)
|
||||
let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
|
||||
for (declaration_id, range) in &self.parent_declarations {
|
||||
if !range.contains_inclusive(&new_range) {
|
||||
break;
|
||||
}
|
||||
parent_declarations.push((*declaration_id, range.clone()));
|
||||
}
|
||||
Self::new(new_range, new_line_range, parent_declarations)
|
||||
}
|
||||
|
||||
fn parent_signatures_size(&self) -> usize {
|
||||
@@ -136,6 +181,7 @@ struct ExcerptSelector<'a> {
|
||||
query_offset: usize,
|
||||
query_range: Range<usize>,
|
||||
query_line_range: Range<Line>,
|
||||
parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
|
||||
buffer: &'a BufferSnapshot,
|
||||
options: &'a EditPredictionExcerptOptions,
|
||||
}
|
||||
@@ -363,7 +409,13 @@ impl<'a> ExcerptSelector<'a> {
|
||||
}
|
||||
|
||||
fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
|
||||
EditPredictionExcerpt::new(range, line_range)
|
||||
let parent_declarations = self
|
||||
.parent_declarations
|
||||
.iter()
|
||||
.filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
|
||||
.map(|(id, declaration)| (*id, declaration.signature_range.clone()))
|
||||
.collect();
|
||||
EditPredictionExcerpt::new(range, line_range, parent_declarations)
|
||||
}
|
||||
|
||||
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
|
||||
@@ -419,14 +471,30 @@ fn node_line_end(node: Node) -> Point {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use language::Buffer;
|
||||
use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
use util::test::{generate_marked_text, marked_text_offsets_by};
|
||||
|
||||
fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot())
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
|
||||
let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
|
||||
(text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
|
||||
@@ -438,8 +506,9 @@ mod tests {
|
||||
let buffer = create_buffer(&text, cx);
|
||||
let cursor_point = cursor.to_point(&buffer);
|
||||
|
||||
let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
|
||||
.expect("Should select an excerpt");
|
||||
let excerpt =
|
||||
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
|
||||
.expect("Should select an excerpt");
|
||||
pretty_assertions::assert_eq!(
|
||||
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
|
||||
generate_marked_text(&text, &[expected_excerpt], false)
|
||||
|
||||
@@ -1,329 +0,0 @@
|
||||
use collections::HashMap;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use language::{Language, LanguageRegistry};
|
||||
use lsp::{
|
||||
FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use project::Fs;
|
||||
use std::{ops::Range, path::PathBuf, sync::Arc};
|
||||
use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree};
|
||||
|
||||
/// Registers a fake language server that implements go-to-definition using tree-sitter,
|
||||
/// making the assumption that all names are unique, and all variables' types are
|
||||
/// explicitly declared.
|
||||
pub fn register_fake_definition_server(
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
language: Arc<Language>,
|
||||
fs: Arc<dyn Fs>,
|
||||
) -> UnboundedReceiver<FakeLanguageServer> {
|
||||
let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone())));
|
||||
|
||||
language_registry.register_fake_lsp(
|
||||
language.name(),
|
||||
language::FakeLspAdapter {
|
||||
name: "fake-definition-lsp",
|
||||
initialization_options: None,
|
||||
prettier_plugins: Vec::new(),
|
||||
disk_based_diagnostics_progress_token: None,
|
||||
disk_based_diagnostics_sources: Vec::new(),
|
||||
language_server_binary: LanguageServerBinary {
|
||||
path: PathBuf::from("fake-definition-lsp"),
|
||||
arguments: Vec::new(),
|
||||
env: None,
|
||||
},
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
definition_provider: Some(lsp::OneOf::Left(true)),
|
||||
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
||||
TextDocumentSyncKind::FULL,
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
label_for_completion: None,
|
||||
initializer: Some(Box::new({
|
||||
move |server| {
|
||||
server.handle_notification::<lsp::notification::DidOpenTextDocument, _>({
|
||||
let index = index.clone();
|
||||
move |params, _cx| {
|
||||
index
|
||||
.lock()
|
||||
.open_buffer(params.text_document.uri, ¶ms.text_document.text);
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidCloseTextDocument, _>({
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
move |params, cx| {
|
||||
let uri = params.text_document.uri;
|
||||
let path = uri.to_file_path().ok();
|
||||
index.lock().mark_buffer_closed(&uri);
|
||||
|
||||
if let Some(path) = path {
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
cx.spawn(async move |_cx| {
|
||||
if let Ok(content) = fs.load(&path).await {
|
||||
index.lock().index_file(uri, &content);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidChangeWatchedFiles, _>({
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
move |params, cx| {
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
cx.spawn(async move |_cx| {
|
||||
for event in params.changes {
|
||||
if index.lock().is_buffer_open(&event.uri) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match event.typ {
|
||||
lsp::FileChangeType::DELETED => {
|
||||
index.lock().remove_definitions_for_file(&event.uri);
|
||||
}
|
||||
lsp::FileChangeType::CREATED
|
||||
| lsp::FileChangeType::CHANGED => {
|
||||
if let Some(path) = event.uri.to_file_path().ok() {
|
||||
if let Ok(content) = fs.load(&path).await {
|
||||
index.lock().index_file(event.uri, &content);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidChangeTextDocument, _>({
|
||||
let index = index.clone();
|
||||
move |params, _cx| {
|
||||
if let Some(change) = params.content_changes.into_iter().last() {
|
||||
index
|
||||
.lock()
|
||||
.index_file(params.text_document.uri, &change.text);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidChangeWorkspaceFolders, _>(
|
||||
{
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
move |params, cx| {
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
let files = fs.as_fake().files();
|
||||
cx.spawn(async move |_cx| {
|
||||
for folder in params.event.added {
|
||||
let Ok(path) = folder.uri.to_file_path() else {
|
||||
continue;
|
||||
};
|
||||
for file in &files {
|
||||
if let Some(uri) = Uri::from_file_path(&file).ok()
|
||||
&& file.starts_with(&path)
|
||||
&& let Ok(content) = fs.load(&file).await
|
||||
{
|
||||
index.lock().index_file(uri, &content);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
server.set_request_handler::<lsp::request::GotoDefinition, _, _>({
|
||||
let index = index.clone();
|
||||
move |params, _cx| {
|
||||
let result = index.lock().get_definitions(
|
||||
params.text_document_position_params.text_document.uri,
|
||||
params.text_document_position_params.position,
|
||||
);
|
||||
async move { Ok(result) }
|
||||
}
|
||||
});
|
||||
}
|
||||
})),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
struct DefinitionIndex {
|
||||
language: Arc<Language>,
|
||||
definitions: HashMap<String, Vec<lsp::Location>>,
|
||||
files: HashMap<Uri, FileEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FileEntry {
|
||||
contents: String,
|
||||
is_open_in_buffer: bool,
|
||||
}
|
||||
|
||||
impl DefinitionIndex {
|
||||
fn new(language: Arc<Language>) -> Self {
|
||||
Self {
|
||||
language,
|
||||
definitions: HashMap::default(),
|
||||
files: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_definitions_for_file(&mut self, uri: &Uri) {
|
||||
self.definitions.retain(|_, locations| {
|
||||
locations.retain(|loc| &loc.uri != uri);
|
||||
!locations.is_empty()
|
||||
});
|
||||
self.files.remove(uri);
|
||||
}
|
||||
|
||||
fn open_buffer(&mut self, uri: Uri, content: &str) {
|
||||
self.index_file_inner(uri, content, true);
|
||||
}
|
||||
|
||||
fn mark_buffer_closed(&mut self, uri: &Uri) {
|
||||
if let Some(entry) = self.files.get_mut(uri) {
|
||||
entry.is_open_in_buffer = false;
|
||||
}
|
||||
}
|
||||
|
||||
fn is_buffer_open(&self, uri: &Uri) -> bool {
|
||||
self.files
|
||||
.get(uri)
|
||||
.map(|entry| entry.is_open_in_buffer)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn index_file(&mut self, uri: Uri, content: &str) {
|
||||
self.index_file_inner(uri, content, false);
|
||||
}
|
||||
|
||||
fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> {
|
||||
self.remove_definitions_for_file(&uri);
|
||||
let grammar = self.language.grammar()?;
|
||||
let outline_config = grammar.outline_config.as_ref()?;
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(&grammar.ts_language).ok()?;
|
||||
let tree = parser.parse(content, None)?;
|
||||
let declarations = extract_declarations_from_tree(&tree, content, outline_config);
|
||||
for (name, byte_range) in declarations {
|
||||
let range = byte_range_to_lsp_range(content, byte_range);
|
||||
let location = lsp::Location {
|
||||
uri: uri.clone(),
|
||||
range,
|
||||
};
|
||||
self.definitions
|
||||
.entry(name)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(location);
|
||||
}
|
||||
self.files.insert(
|
||||
uri,
|
||||
FileEntry {
|
||||
contents: content.to_string(),
|
||||
is_open_in_buffer,
|
||||
},
|
||||
);
|
||||
|
||||
Some(())
|
||||
}
|
||||
|
||||
fn get_definitions(
|
||||
&mut self,
|
||||
uri: Uri,
|
||||
position: lsp::Position,
|
||||
) -> Option<lsp::GotoDefinitionResponse> {
|
||||
let entry = self.files.get(&uri)?;
|
||||
let name = word_at_position(&entry.contents, position)?;
|
||||
let locations = self.definitions.get(name).cloned()?;
|
||||
Some(lsp::GotoDefinitionResponse::Array(locations))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_declarations_from_tree(
|
||||
tree: &Tree,
|
||||
content: &str,
|
||||
outline_config: &language::OutlineConfig,
|
||||
) -> Vec<(String, Range<usize>)> {
|
||||
let mut cursor = QueryCursor::new();
|
||||
let mut declarations = Vec::new();
|
||||
let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes());
|
||||
while let Some(query_match) = matches.next() {
|
||||
let mut name_range: Option<Range<usize>> = None;
|
||||
let mut has_item_range = false;
|
||||
|
||||
for capture in query_match.captures {
|
||||
let range = capture.node.byte_range();
|
||||
if capture.index == outline_config.name_capture_ix {
|
||||
name_range = Some(range);
|
||||
} else if capture.index == outline_config.item_capture_ix {
|
||||
has_item_range = true;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(name_range) = name_range
|
||||
&& has_item_range
|
||||
{
|
||||
let name = content[name_range.clone()].to_string();
|
||||
if declarations.iter().any(|(n, _)| n == &name) {
|
||||
continue;
|
||||
}
|
||||
declarations.push((name, name_range));
|
||||
}
|
||||
}
|
||||
declarations
|
||||
}
|
||||
|
||||
fn byte_range_to_lsp_range(content: &str, byte_range: Range<usize>) -> lsp::Range {
|
||||
let start = byte_offset_to_position(content, byte_range.start);
|
||||
let end = byte_offset_to_position(content, byte_range.end);
|
||||
lsp::Range { start, end }
|
||||
}
|
||||
|
||||
/// Converts a byte offset within `content` into an LSP position by walking
/// characters up to (but not including) the character starting at `offset`.
///
/// NOTE(review): `character` counts Unicode scalar values, while the LSP
/// spec's default position encoding is UTF-16 code units — identical for
/// ASCII content; confirm before indexing non-ASCII fixtures.
fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position {
    let mut position = lsp::Position {
        line: 0,
        character: 0,
    };
    // `char_indices` yields the byte offset at which each char starts, so we
    // stop as soon as we reach the char containing (or starting at) `offset`.
    for (byte_ix, ch) in content.char_indices() {
        if byte_ix >= offset {
            break;
        }
        if ch == '\n' {
            position.line += 1;
            position.character = 0;
        } else {
            position.character += 1;
        }
    }
    position
}
|
||||
|
||||
fn word_at_position(content: &str, position: lsp::Position) -> Option<&str> {
|
||||
let mut lines = content.lines();
|
||||
let line = lines.nth(position.line as usize)?;
|
||||
let column = position.character as usize;
|
||||
if column > line.len() {
|
||||
return None;
|
||||
}
|
||||
let start = line[..column]
|
||||
.rfind(|c: char| !c.is_alphanumeric() && c != '_')
|
||||
.map(|i| i + 1)
|
||||
.unwrap_or(0);
|
||||
let end = line[column..]
|
||||
.find(|c: char| !c.is_alphanumeric() && c != '_')
|
||||
.map(|i| i + column)
|
||||
.unwrap_or(line.len());
|
||||
Some(&line[start..end]).filter(|word| !word.is_empty())
|
||||
}
|
||||
1319
crates/edit_prediction_context/src/imports.rs
Normal file
1319
crates/edit_prediction_context/src/imports.rs
Normal file
File diff suppressed because it is too large
Load Diff
126
crates/edit_prediction_context/src/outline.rs
Normal file
126
crates/edit_prediction_context/src/outline.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use language::{BufferSnapshot, SyntaxMapMatches};
|
||||
use std::{cmp::Reverse, ops::Range};
|
||||
|
||||
use crate::declaration::Identifier;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * how to handle multiple name captures? for now last one wins
|
||||
//
|
||||
// * annotation ranges
|
||||
//
|
||||
// * new "signature" capture for outline queries
|
||||
//
|
||||
// * Check parent behavior of "int x, y = 0" declarations in a test
|
||||
|
||||
/// A declaration found by running a language's outline query over a buffer,
/// positioned by byte ranges into that buffer.
pub struct OutlineDeclaration {
    // Index (into the containing `Vec`) of the innermost enclosing
    // declaration; populated by `declarations_overlapping_range`, left
    // `None` by `OutlineIterator`.
    pub parent_index: Option<usize>,
    pub identifier: Identifier,
    // Byte range of the whole declaration (the outline `item` capture).
    pub item_range: Range<usize>,
    // Byte range spanning the name and context captures of the match.
    pub signature_range: Range<usize>,
}
|
||||
|
||||
/// Collects the outline declarations for the entire buffer, sorted and with
/// `parent_index` populated — see `declarations_overlapping_range`.
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
    declarations_overlapping_range(0..buffer.len(), buffer)
}
|
||||
|
||||
/// Collects outline declarations whose items overlap `range`, sorted so that
/// an enclosing declaration precedes the declarations nested inside it, with
/// each `parent_index` pointing at the innermost enclosing declaration.
pub fn declarations_overlapping_range(
    range: Range<usize>,
    buffer: &BufferSnapshot,
) -> Vec<OutlineDeclaration> {
    let mut declarations = OutlineIterator::new(range, buffer).collect::<Vec<_>>();
    // Sort by start ascending, end descending: an item that contains another
    // then always sorts before it, which the stack walk below relies on.
    declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end)));

    // Stack of (index, item_range) for declarations whose ranges still
    // enclose the current position in sorted order.
    let mut parent_stack: Vec<(usize, Range<usize>)> = Vec::new();
    for (index, declaration) in declarations.iter_mut().enumerate() {
        while let Some((top_parent_index, top_parent_range)) = parent_stack.last() {
            if declaration.item_range.start >= top_parent_range.end {
                // The stack top ends before this declaration starts, so it
                // cannot be an ancestor — unwind it.
                parent_stack.pop();
            } else {
                // First surviving entry is the innermost enclosing item.
                declaration.parent_index = Some(*top_parent_index);
                break;
            }
        }
        parent_stack.push((index, declaration.item_range.clone()));
    }
    declarations
}
|
||||
|
||||
/// Iterates outline declarations in query-match order: results are not
/// ordered with respect to nesting and `parent_index` is left `None`.
/// Use `declarations_overlapping_range` for sorted, parented results.
pub struct OutlineIterator<'a> {
    buffer: &'a BufferSnapshot,
    matches: SyntaxMapMatches<'a>,
}
|
||||
|
||||
impl<'a> OutlineIterator<'a> {
|
||||
pub fn new(range: Range<usize>, buffer: &'a BufferSnapshot) -> Self {
|
||||
let matches = buffer.syntax.matches(range, &buffer.text, |grammar| {
|
||||
grammar.outline_config.as_ref().map(|c| &c.query)
|
||||
});
|
||||
|
||||
Self { buffer, matches }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for OutlineIterator<'a> {
|
||||
type Item = OutlineDeclaration;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(mat) = self.matches.peek() {
|
||||
let config = self.matches.grammars()[mat.grammar_index]
|
||||
.outline_config
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
|
||||
let mut name_range = None;
|
||||
let mut item_range = None;
|
||||
let mut signature_start = None;
|
||||
let mut signature_end = None;
|
||||
|
||||
let mut add_to_signature = |range: Range<usize>| {
|
||||
if signature_start.is_none() {
|
||||
signature_start = Some(range.start);
|
||||
}
|
||||
signature_end = Some(range.end);
|
||||
};
|
||||
|
||||
for capture in mat.captures {
|
||||
let range = capture.node.byte_range();
|
||||
if capture.index == config.name_capture_ix {
|
||||
name_range = Some(range.clone());
|
||||
add_to_signature(range);
|
||||
} else if Some(capture.index) == config.context_capture_ix
|
||||
|| Some(capture.index) == config.extra_context_capture_ix
|
||||
{
|
||||
add_to_signature(range);
|
||||
} else if capture.index == config.item_capture_ix {
|
||||
item_range = Some(range.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let language_id = mat.language.id();
|
||||
self.matches.advance();
|
||||
|
||||
if let Some(name_range) = name_range
|
||||
&& let Some(item_range) = item_range
|
||||
&& let Some(signature_start) = signature_start
|
||||
&& let Some(signature_end) = signature_end
|
||||
{
|
||||
let name = self
|
||||
.buffer
|
||||
.text_for_range(name_range)
|
||||
.collect::<String>()
|
||||
.into();
|
||||
|
||||
return Some(OutlineDeclaration {
|
||||
identifier: Identifier { name, language_id },
|
||||
item_range: item_range,
|
||||
signature_range: signature_start..signature_end,
|
||||
parent_index: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user