Compare commits
29 Commits
fix-linux-
...
print-live
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c90c1880bb | ||
|
|
822fc7ef16 | ||
|
|
126d708fa1 | ||
|
|
a5ab5c7d5d | ||
|
|
35da6d000a | ||
|
|
d6241b17d3 | ||
|
|
42583c1141 | ||
|
|
76167109db | ||
|
|
cd8679e81a | ||
|
|
43f977c6b9 | ||
|
|
bdb8caa42e | ||
|
|
9ae77ec3c9 | ||
|
|
d5ed9d3e3a | ||
|
|
74a1b5d14d | ||
|
|
07af011eb4 | ||
|
|
c357dc25fc | ||
|
|
93bc6616c6 | ||
|
|
a33e881906 | ||
|
|
c978db8626 | ||
|
|
2dad46c5c0 | ||
|
|
4c51fffbb5 | ||
|
|
0d80b452fb | ||
|
|
bad6bde03a | ||
|
|
4ec2d04ad9 | ||
|
|
0f0017dc8e | ||
|
|
9db0d66251 | ||
|
|
b07389d9f3 | ||
|
|
db2e26f67b | ||
|
|
391c92b07a |
1305
Cargo.lock
generated
1305
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
16
Cargo.toml
16
Cargo.toml
@@ -54,9 +54,9 @@ members = [
|
||||
"crates/diagnostics",
|
||||
"crates/docs_preprocessor",
|
||||
"crates/edit_prediction",
|
||||
"crates/edit_prediction_button",
|
||||
"crates/edit_prediction_types",
|
||||
"crates/edit_prediction_ui",
|
||||
"crates/edit_prediction_context",
|
||||
"crates/zeta2_tools",
|
||||
"crates/editor",
|
||||
"crates/eval",
|
||||
"crates/eval_utils",
|
||||
@@ -201,8 +201,7 @@ members = [
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zed_env_vars",
|
||||
"crates/zeta",
|
||||
"crates/zeta_cli",
|
||||
"crates/edit_prediction_cli",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
|
||||
@@ -313,10 +312,9 @@ http_client = { path = "crates/http_client" }
|
||||
http_client_tls = { path = "crates/http_client_tls" }
|
||||
icons = { path = "crates/icons" }
|
||||
image_viewer = { path = "crates/image_viewer" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
edit_prediction_button = { path = "crates/edit_prediction_button" }
|
||||
edit_prediction_types = { path = "crates/edit_prediction_types" }
|
||||
edit_prediction_ui = { path = "crates/edit_prediction_ui" }
|
||||
edit_prediction_context = { path = "crates/edit_prediction_context" }
|
||||
zeta2_tools = { path = "crates/zeta2_tools" }
|
||||
inspector_ui = { path = "crates/inspector_ui" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
journal = { path = "crates/journal" }
|
||||
@@ -433,7 +431,7 @@ x_ai = { path = "crates/x_ai" }
|
||||
zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
zed_env_vars = { path = "crates/zed_env_vars" }
|
||||
zeta = { path = "crates/zeta" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
zlog = { path = "crates/zlog" }
|
||||
zlog_settings = { path = "crates/zlog_settings" }
|
||||
|
||||
@@ -828,7 +826,7 @@ feature_flags = { codegen-units = 1 }
|
||||
file_icons = { codegen-units = 1 }
|
||||
fsevent = { codegen-units = 1 }
|
||||
image_viewer = { codegen-units = 1 }
|
||||
edit_prediction_button = { codegen-units = 1 }
|
||||
edit_prediction_ui = { codegen-units = 1 }
|
||||
install_cli = { codegen-units = 1 }
|
||||
journal = { codegen-units = 1 }
|
||||
json_schema_store = { codegen-units = 1 }
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
"ctrl-f11": "debugger::StepInto",
|
||||
"shift-f11": "debugger::StepOut",
|
||||
"f11": "zed::ToggleFullScreen",
|
||||
"ctrl-alt-z": "edit_prediction::RateCompletions",
|
||||
"ctrl-alt-z": "edit_prediction::RatePredictions",
|
||||
"ctrl-alt-shift-i": "edit_prediction::ToggleMenu",
|
||||
"ctrl-alt-l": "lsp_tool::ToggleMenu"
|
||||
}
|
||||
@@ -1322,25 +1322,18 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Feedback > Editor",
|
||||
"context": "EditPredictionContext > Editor",
|
||||
"bindings": {
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter up": "dev::Zeta2RatePredictionPositive",
|
||||
"ctrl-enter down": "dev::Zeta2RatePredictionNegative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Context > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::Zeta2ContextGoBack",
|
||||
"alt-right": "dev::Zeta2ContextGoForward"
|
||||
"alt-left": "dev::EditPredictionContextGoBack",
|
||||
"alt-right": "dev::EditPredictionContextGoForward"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch"
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch",
|
||||
"ctrl-shift-i": "branch_picker::FilterRemotes"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -47,7 +47,7 @@
|
||||
"cmd-m": "zed::Minimize",
|
||||
"fn-f": "zed::ToggleFullScreen",
|
||||
"ctrl-cmd-f": "zed::ToggleFullScreen",
|
||||
"ctrl-cmd-z": "edit_prediction::RateCompletions",
|
||||
"ctrl-cmd-z": "edit_prediction::RatePredictions",
|
||||
"ctrl-cmd-i": "edit_prediction::ToggleMenu",
|
||||
"ctrl-cmd-l": "lsp_tool::ToggleMenu",
|
||||
"ctrl-cmd-c": "editor::DisplayCursorNames"
|
||||
@@ -1427,25 +1427,18 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Feedback > Editor",
|
||||
"context": "EditPredictionContext > Editor",
|
||||
"bindings": {
|
||||
"enter": "editor::Newline",
|
||||
"cmd-enter up": "dev::Zeta2RatePredictionPositive",
|
||||
"cmd-enter down": "dev::Zeta2RatePredictionNegative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Context > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::Zeta2ContextGoBack",
|
||||
"alt-right": "dev::Zeta2ContextGoForward"
|
||||
"alt-left": "dev::EditPredictionContextGoBack",
|
||||
"alt-right": "dev::EditPredictionContextGoForward"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-shift-backspace": "branch_picker::DeleteBranch"
|
||||
"cmd-shift-backspace": "branch_picker::DeleteBranch",
|
||||
"cmd-shift-i": "branch_picker::FilterRemotes"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1341,25 +1341,18 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Feedback > Editor",
|
||||
"context": "EditPredictionContext > Editor",
|
||||
"bindings": {
|
||||
"enter": "editor::Newline",
|
||||
"ctrl-enter up": "dev::Zeta2RatePredictionPositive",
|
||||
"ctrl-enter down": "dev::Zeta2RatePredictionNegative"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Zeta2Context > Editor",
|
||||
"bindings": {
|
||||
"alt-left": "dev::Zeta2ContextGoBack",
|
||||
"alt-right": "dev::Zeta2ContextGoForward"
|
||||
"alt-left": "dev::EditPredictionContextGoBack",
|
||||
"alt-right": "dev::EditPredictionContextGoForward"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch"
|
||||
"ctrl-shift-backspace": "branch_picker::DeleteBranch",
|
||||
"ctrl-shift-i": "branch_picker::FilterRemotes"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -584,41 +584,100 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cross_region_inference_id(&self, region: &str) -> anyhow::Result<String> {
|
||||
pub fn cross_region_inference_id(
|
||||
&self,
|
||||
region: &str,
|
||||
allow_global: bool,
|
||||
) -> anyhow::Result<String> {
|
||||
// List derived from here:
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html#inference-profiles-support-system
|
||||
let model_id = self.request_id();
|
||||
|
||||
let supports_global = matches!(
|
||||
self,
|
||||
Model::ClaudeOpus4_5
|
||||
| Model::ClaudeOpus4_5Thinking
|
||||
| Model::ClaudeHaiku4_5
|
||||
| Model::ClaudeSonnet4
|
||||
| Model::ClaudeSonnet4Thinking
|
||||
| Model::ClaudeSonnet4_5
|
||||
| Model::ClaudeSonnet4_5Thinking
|
||||
);
|
||||
|
||||
let region_group = if region.starts_with("us-gov-") {
|
||||
"us-gov"
|
||||
} else if region.starts_with("us-") {
|
||||
"us"
|
||||
} else if region.starts_with("us-")
|
||||
|| region.starts_with("ca-")
|
||||
|| region.starts_with("sa-")
|
||||
{
|
||||
if allow_global && supports_global {
|
||||
"global"
|
||||
} else {
|
||||
"us"
|
||||
}
|
||||
} else if region.starts_with("eu-") {
|
||||
"eu"
|
||||
if allow_global && supports_global {
|
||||
"global"
|
||||
} else {
|
||||
"eu"
|
||||
}
|
||||
} else if region.starts_with("ap-") || region == "me-central-1" || region == "me-south-1" {
|
||||
"apac"
|
||||
} else if region.starts_with("ca-") || region.starts_with("sa-") {
|
||||
// Canada and South America regions - default to US profiles
|
||||
"us"
|
||||
if allow_global && supports_global {
|
||||
"global"
|
||||
} else {
|
||||
"apac"
|
||||
}
|
||||
} else {
|
||||
anyhow::bail!("Unsupported Region {region}");
|
||||
};
|
||||
|
||||
let model_id = self.request_id();
|
||||
match (self, region_group, region) {
|
||||
(Model::Custom { .. }, _, _) => Ok(self.request_id().into()),
|
||||
|
||||
match (self, region_group) {
|
||||
// Custom models can't have CRI IDs
|
||||
(Model::Custom { .. }, _) => Ok(self.request_id().into()),
|
||||
(
|
||||
Model::ClaudeOpus4_5
|
||||
| Model::ClaudeOpus4_5Thinking
|
||||
| Model::ClaudeHaiku4_5
|
||||
| Model::ClaudeSonnet4
|
||||
| Model::ClaudeSonnet4Thinking
|
||||
| Model::ClaudeSonnet4_5
|
||||
| Model::ClaudeSonnet4_5Thinking,
|
||||
"global",
|
||||
_,
|
||||
) => Ok(format!("{}.{}", region_group, model_id)),
|
||||
|
||||
// Models with US Gov only
|
||||
(Model::Claude3_5Sonnet, "us-gov") | (Model::Claude3Haiku, "us-gov") => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
(
|
||||
Model::Claude3Haiku
|
||||
| Model::Claude3_5Sonnet
|
||||
| Model::Claude3_7Sonnet
|
||||
| Model::Claude3_7SonnetThinking
|
||||
| Model::ClaudeSonnet4_5
|
||||
| Model::ClaudeSonnet4_5Thinking,
|
||||
"us-gov",
|
||||
_,
|
||||
) => Ok(format!("{}.{}", region_group, model_id)),
|
||||
|
||||
(
|
||||
Model::ClaudeHaiku4_5 | Model::ClaudeSonnet4_5 | Model::ClaudeSonnet4_5Thinking,
|
||||
"apac",
|
||||
"ap-southeast-2" | "ap-southeast-4",
|
||||
) => Ok(format!("au.{}", model_id)),
|
||||
|
||||
(
|
||||
Model::ClaudeHaiku4_5 | Model::ClaudeSonnet4_5 | Model::ClaudeSonnet4_5Thinking,
|
||||
"apac",
|
||||
"ap-northeast-1" | "ap-northeast-3",
|
||||
) => Ok(format!("jp.{}", model_id)),
|
||||
|
||||
(Model::AmazonNovaLite, "us", r) if r.starts_with("ca-") => {
|
||||
Ok(format!("ca.{}", model_id))
|
||||
}
|
||||
|
||||
// Available everywhere
|
||||
(Model::AmazonNovaLite | Model::AmazonNovaMicro | Model::AmazonNovaPro, _) => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// Models in US
|
||||
(
|
||||
Model::AmazonNovaPremier
|
||||
| Model::AmazonNovaLite
|
||||
| Model::AmazonNovaMicro
|
||||
| Model::AmazonNovaPro
|
||||
| Model::Claude3_5Haiku
|
||||
| Model::ClaudeHaiku4_5
|
||||
| Model::Claude3_5Sonnet
|
||||
@@ -655,16 +714,18 @@ impl Model {
|
||||
| Model::PalmyraWriterX4
|
||||
| Model::PalmyraWriterX5,
|
||||
"us",
|
||||
_,
|
||||
) => Ok(format!("{}.{}", region_group, model_id)),
|
||||
|
||||
// Models available in EU
|
||||
(
|
||||
Model::Claude3_5Sonnet
|
||||
Model::AmazonNovaLite
|
||||
| Model::AmazonNovaMicro
|
||||
| Model::AmazonNovaPro
|
||||
| Model::Claude3_5Sonnet
|
||||
| Model::ClaudeHaiku4_5
|
||||
| Model::Claude3_7Sonnet
|
||||
| Model::Claude3_7SonnetThinking
|
||||
| Model::ClaudeSonnet4
|
||||
| Model::ClaudeSonnet4Thinking
|
||||
| Model::ClaudeSonnet4_5
|
||||
| Model::ClaudeSonnet4_5Thinking
|
||||
| Model::Claude3Haiku
|
||||
@@ -673,26 +734,26 @@ impl Model {
|
||||
| Model::MetaLlama323BInstructV1
|
||||
| Model::MistralPixtralLarge2502V1,
|
||||
"eu",
|
||||
_,
|
||||
) => Ok(format!("{}.{}", region_group, model_id)),
|
||||
|
||||
// Models available in APAC
|
||||
(
|
||||
Model::Claude3_5Sonnet
|
||||
Model::AmazonNovaLite
|
||||
| Model::AmazonNovaMicro
|
||||
| Model::AmazonNovaPro
|
||||
| Model::Claude3_5Sonnet
|
||||
| Model::Claude3_5SonnetV2
|
||||
| Model::ClaudeHaiku4_5
|
||||
| Model::Claude3Haiku
|
||||
| Model::Claude3Sonnet
|
||||
| Model::Claude3_7Sonnet
|
||||
| Model::Claude3_7SonnetThinking
|
||||
| Model::ClaudeSonnet4
|
||||
| Model::ClaudeSonnet4Thinking
|
||||
| Model::ClaudeSonnet4_5
|
||||
| Model::ClaudeSonnet4_5Thinking,
|
||||
| Model::Claude3Haiku
|
||||
| Model::Claude3Sonnet,
|
||||
"apac",
|
||||
_,
|
||||
) => Ok(format!("{}.{}", region_group, model_id)),
|
||||
|
||||
// Any other combination is not supported
|
||||
_ => Ok(self.request_id().into()),
|
||||
_ => Ok(model_id.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -705,15 +766,15 @@ mod tests {
|
||||
fn test_us_region_inference_ids() -> anyhow::Result<()> {
|
||||
// Test US regions
|
||||
assert_eq!(
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1")?,
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("us-east-1", false)?,
|
||||
"us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2")?,
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("us-west-2", false)?,
|
||||
"us.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::AmazonNovaPro.cross_region_inference_id("us-east-2")?,
|
||||
Model::AmazonNovaPro.cross_region_inference_id("us-east-2", false)?,
|
||||
"us.amazon.nova-pro-v1:0"
|
||||
);
|
||||
Ok(())
|
||||
@@ -723,19 +784,19 @@ mod tests {
|
||||
fn test_eu_region_inference_ids() -> anyhow::Result<()> {
|
||||
// Test European regions
|
||||
assert_eq!(
|
||||
Model::ClaudeSonnet4.cross_region_inference_id("eu-west-1")?,
|
||||
Model::ClaudeSonnet4.cross_region_inference_id("eu-west-1", false)?,
|
||||
"eu.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::ClaudeSonnet4_5.cross_region_inference_id("eu-west-1")?,
|
||||
Model::ClaudeSonnet4_5.cross_region_inference_id("eu-west-1", false)?,
|
||||
"eu.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::Claude3Sonnet.cross_region_inference_id("eu-west-1")?,
|
||||
Model::Claude3Sonnet.cross_region_inference_id("eu-west-1", false)?,
|
||||
"eu.anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1")?,
|
||||
Model::AmazonNovaMicro.cross_region_inference_id("eu-north-1", false)?,
|
||||
"eu.amazon.nova-micro-v1:0"
|
||||
);
|
||||
Ok(())
|
||||
@@ -745,15 +806,15 @@ mod tests {
|
||||
fn test_apac_region_inference_ids() -> anyhow::Result<()> {
|
||||
// Test Asia-Pacific regions
|
||||
assert_eq!(
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1")?,
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("ap-northeast-1", false)?,
|
||||
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("ap-southeast-2")?,
|
||||
Model::Claude3_5SonnetV2.cross_region_inference_id("ap-southeast-2", false)?,
|
||||
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::AmazonNovaLite.cross_region_inference_id("ap-south-1")?,
|
||||
Model::AmazonNovaLite.cross_region_inference_id("ap-south-1", false)?,
|
||||
"apac.amazon.nova-lite-v1:0"
|
||||
);
|
||||
Ok(())
|
||||
@@ -763,11 +824,11 @@ mod tests {
|
||||
fn test_gov_region_inference_ids() -> anyhow::Result<()> {
|
||||
// Test Government regions
|
||||
assert_eq!(
|
||||
Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1")?,
|
||||
Model::Claude3_5Sonnet.cross_region_inference_id("us-gov-east-1", false)?,
|
||||
"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1")?,
|
||||
Model::Claude3Haiku.cross_region_inference_id("us-gov-west-1", false)?,
|
||||
"us-gov.anthropic.claude-3-haiku-20240307-v1:0"
|
||||
);
|
||||
Ok(())
|
||||
@@ -777,15 +838,15 @@ mod tests {
|
||||
fn test_meta_models_inference_ids() -> anyhow::Result<()> {
|
||||
// Test Meta models
|
||||
assert_eq!(
|
||||
Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1")?,
|
||||
Model::MetaLlama370BInstructV1.cross_region_inference_id("us-east-1", false)?,
|
||||
"meta.llama3-70b-instruct-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::MetaLlama3170BInstructV1.cross_region_inference_id("us-east-1")?,
|
||||
Model::MetaLlama3170BInstructV1.cross_region_inference_id("us-east-1", false)?,
|
||||
"us.meta.llama3-1-70b-instruct-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1")?,
|
||||
Model::MetaLlama321BInstructV1.cross_region_inference_id("eu-west-1", false)?,
|
||||
"eu.meta.llama3-2-1b-instruct-v1:0"
|
||||
);
|
||||
Ok(())
|
||||
@@ -796,11 +857,11 @@ mod tests {
|
||||
// Mistral models don't follow the regional prefix pattern,
|
||||
// so they should return their original IDs
|
||||
assert_eq!(
|
||||
Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1")?,
|
||||
Model::MistralMistralLarge2402V1.cross_region_inference_id("us-east-1", false)?,
|
||||
"mistral.mistral-large-2402-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1")?,
|
||||
Model::MistralMixtral8x7BInstructV0.cross_region_inference_id("eu-west-1", false)?,
|
||||
"mistral.mixtral-8x7b-instruct-v0:1"
|
||||
);
|
||||
Ok(())
|
||||
@@ -811,11 +872,11 @@ mod tests {
|
||||
// AI21 models don't follow the regional prefix pattern,
|
||||
// so they should return their original IDs
|
||||
assert_eq!(
|
||||
Model::AI21J2UltraV1.cross_region_inference_id("us-east-1")?,
|
||||
Model::AI21J2UltraV1.cross_region_inference_id("us-east-1", false)?,
|
||||
"ai21.j2-ultra-v1"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1")?,
|
||||
Model::AI21JambaInstructV1.cross_region_inference_id("eu-west-1", false)?,
|
||||
"ai21.jamba-instruct-v1:0"
|
||||
);
|
||||
Ok(())
|
||||
@@ -826,11 +887,11 @@ mod tests {
|
||||
// Cohere models don't follow the regional prefix pattern,
|
||||
// so they should return their original IDs
|
||||
assert_eq!(
|
||||
Model::CohereCommandRV1.cross_region_inference_id("us-east-1")?,
|
||||
Model::CohereCommandRV1.cross_region_inference_id("us-east-1", false)?,
|
||||
"cohere.command-r-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1")?,
|
||||
Model::CohereCommandTextV14_4k.cross_region_inference_id("ap-southeast-1", false)?,
|
||||
"cohere.command-text-v14:7:4k"
|
||||
);
|
||||
Ok(())
|
||||
@@ -850,10 +911,17 @@ mod tests {
|
||||
|
||||
// Custom model should return its name unchanged
|
||||
assert_eq!(
|
||||
custom_model.cross_region_inference_id("us-east-1")?,
|
||||
custom_model.cross_region_inference_id("us-east-1", false)?,
|
||||
"custom.my-model-v1:0"
|
||||
);
|
||||
|
||||
// Test that models without global support fall back to regional when allow_global is true
|
||||
assert_eq!(
|
||||
Model::AmazonNovaPro.cross_region_inference_id("us-east-1", true)?,
|
||||
"us.amazon.nova-pro-v1:0",
|
||||
"Nova Pro should fall back to regional profile even when allow_global is true"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -892,3 +960,28 @@ mod tests {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_global_inference_ids() -> anyhow::Result<()> {
|
||||
// Test global inference for models that support it when allow_global is true
|
||||
assert_eq!(
|
||||
Model::ClaudeSonnet4.cross_region_inference_id("us-east-1", true)?,
|
||||
"global.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::ClaudeSonnet4_5.cross_region_inference_id("eu-west-1", true)?,
|
||||
"global.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
);
|
||||
assert_eq!(
|
||||
Model::ClaudeHaiku4_5.cross_region_inference_id("ap-south-1", true)?,
|
||||
"global.anthropic.claude-haiku-4-5-20251001-v1:0"
|
||||
);
|
||||
|
||||
// Test that regional prefix is used when allow_global is false
|
||||
assert_eq!(
|
||||
Model::ClaudeSonnet4.cross_region_inference_id("us-east-1", false)?,
|
||||
"us.anthropic.claude-sonnet-4-20250514-v1:0"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1723,6 +1723,10 @@ impl ProtoClient for Client {
|
||||
fn is_via_collab(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn has_wsl_interop(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// prefix for the zed:// url scheme
|
||||
|
||||
@@ -31,18 +31,10 @@ pub struct PredictEditsRequest {
|
||||
/// Within `signatures`
|
||||
pub excerpt_parent: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub included_files: Vec<IncludedFile>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub signatures: Vec<Signature>,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub referenced_declarations: Vec<ReferencedDeclaration>,
|
||||
pub related_files: Vec<RelatedFile>,
|
||||
pub events: Vec<Arc<Event>>,
|
||||
#[serde(default)]
|
||||
pub can_collect_data: bool,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty", default)]
|
||||
pub diagnostic_groups: Vec<DiagnosticGroup>,
|
||||
#[serde(skip_serializing_if = "is_default", default)]
|
||||
pub diagnostic_groups_truncated: bool,
|
||||
/// Info about the git repository state, only present when can_collect_data is true.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub git_info: Option<PredictEditsGitInfo>,
|
||||
@@ -58,7 +50,7 @@ pub struct PredictEditsRequest {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IncludedFile {
|
||||
pub struct RelatedFile {
|
||||
pub path: Arc<Path>,
|
||||
pub max_row: Line,
|
||||
pub excerpts: Vec<Excerpt>,
|
||||
@@ -72,11 +64,9 @@ pub struct Excerpt {
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum PromptFormat {
|
||||
MarkedExcerpt,
|
||||
LabeledSections,
|
||||
NumLinesUniDiff,
|
||||
/// XML old_tex/new_text
|
||||
OldTextNewText,
|
||||
/// Prompt format intended for use via zeta_cli
|
||||
/// Prompt format intended for use via edit_prediction_cli
|
||||
OnlySnippets,
|
||||
/// One-sentence instructions used in fine-tuned models
|
||||
Minimal,
|
||||
@@ -87,7 +77,7 @@ pub enum PromptFormat {
|
||||
}
|
||||
|
||||
impl PromptFormat {
|
||||
pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
|
||||
pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
|
||||
}
|
||||
|
||||
impl Default for PromptFormat {
|
||||
@@ -105,10 +95,7 @@ impl PromptFormat {
|
||||
impl std::fmt::Display for PromptFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
|
||||
PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
|
||||
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
|
||||
PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
|
||||
PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
|
||||
PromptFormat::Minimal => write!(f, "Minimal"),
|
||||
PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
|
||||
@@ -178,67 +165,6 @@ impl<'a> std::fmt::Display for DiffPathFmt<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Signature {
|
||||
pub text: String,
|
||||
pub text_is_truncated: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub parent_index: Option<usize>,
|
||||
/// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
|
||||
/// file is implicitly the file that contains the descendant declaration or excerpt.
|
||||
pub range: Range<Line>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReferencedDeclaration {
|
||||
pub path: Arc<Path>,
|
||||
pub text: String,
|
||||
pub text_is_truncated: bool,
|
||||
/// Range of `text` within file, possibly truncated according to `text_is_truncated`
|
||||
pub range: Range<Line>,
|
||||
/// Range within `text`
|
||||
pub signature_range: Range<usize>,
|
||||
/// Index within `signatures`.
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub parent_index: Option<usize>,
|
||||
pub score_components: DeclarationScoreComponents,
|
||||
pub signature_score: f32,
|
||||
pub declaration_score: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeclarationScoreComponents {
|
||||
pub is_same_file: bool,
|
||||
pub is_referenced_nearby: bool,
|
||||
pub is_referenced_in_breadcrumb: bool,
|
||||
pub reference_count: usize,
|
||||
pub same_file_declaration_count: usize,
|
||||
pub declaration_count: usize,
|
||||
pub reference_line_distance: u32,
|
||||
pub declaration_line_distance: u32,
|
||||
pub excerpt_vs_item_jaccard: f32,
|
||||
pub excerpt_vs_signature_jaccard: f32,
|
||||
pub adjacent_vs_item_jaccard: f32,
|
||||
pub adjacent_vs_signature_jaccard: f32,
|
||||
pub excerpt_vs_item_weighted_overlap: f32,
|
||||
pub excerpt_vs_signature_weighted_overlap: f32,
|
||||
pub adjacent_vs_item_weighted_overlap: f32,
|
||||
pub adjacent_vs_signature_weighted_overlap: f32,
|
||||
pub path_import_match_count: usize,
|
||||
pub wildcard_path_import_match_count: usize,
|
||||
pub import_similarity: f32,
|
||||
pub max_import_similarity: f32,
|
||||
pub normalized_import_similarity: f32,
|
||||
pub wildcard_import_similarity: f32,
|
||||
pub normalized_wildcard_import_similarity: f32,
|
||||
pub included_by_others: usize,
|
||||
pub includes_others: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsResponse {
|
||||
pub request_id: Uuid,
|
||||
@@ -262,10 +188,6 @@ pub struct Edit {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
|
||||
*value == T::default()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
|
||||
pub struct Point {
|
||||
pub line: Line,
|
||||
|
||||
@@ -15,9 +15,4 @@ path = "src/cloud_zeta2_prompt.rs"
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
indoc.workspace = true
|
||||
ordered-float.workspace = true
|
||||
rustc-hash.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
|
||||
@@ -1,20 +1,12 @@
|
||||
//! Zeta2 prompt planning and generation code shared with cloud.
|
||||
pub mod retrieval_prompt;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat,
|
||||
ReferencedDeclaration,
|
||||
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use ordered_float::OrderedFloat;
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use serde::Serialize;
|
||||
use std::cmp;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
|
||||
use strum::{EnumIter, IntoEnumIterator};
|
||||
|
||||
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
|
||||
|
||||
@@ -24,69 +16,6 @@ pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_s
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
|
||||
|
||||
// TODO: use constants for markers?
|
||||
const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
|
||||
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
|
||||
|
||||
The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region.
|
||||
|
||||
Other code is provided for context, and `…` indicates when code has been skipped.
|
||||
|
||||
## Edit History
|
||||
|
||||
"};
|
||||
|
||||
const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.
|
||||
|
||||
Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).
|
||||
|
||||
The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.
|
||||
|
||||
Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:
|
||||
|
||||
<|current_section|>
|
||||
for i in 0..16 {
|
||||
println!("{i}");
|
||||
}
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
|
||||
# Instructions
|
||||
|
||||
You are an edit prediction agent in a code editor.
|
||||
Your job is to predict the next edit that the user will make,
|
||||
based on their last few edits and their current cursor location.
|
||||
|
||||
## Output Format
|
||||
|
||||
You must briefly explain your understanding of the user's goal, in one
|
||||
or two sentences, and then specify their next edit in the form of a
|
||||
unified diff, like this:
|
||||
|
||||
```
|
||||
--- a/src/myapp/cli.py
|
||||
+++ b/src/myapp/cli.py
|
||||
@@ ... @@
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
+from constants import LOG_LEVEL_WARNING
|
||||
@@ ... @@
|
||||
config.headless()
|
||||
config.set_interactive(false)
|
||||
-config.set_log_level(LOG_L)
|
||||
+config.set_log_level(LOG_LEVEL_WARNING)
|
||||
config.set_use_color(True)
|
||||
```
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
|
||||
|
||||
@@ -94,20 +23,6 @@ const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
|
||||
"#};
|
||||
|
||||
const UNIFIED_DIFF_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
Analyze the edit history and the files, then provide the unified diff for your predicted edits.
|
||||
Do not include the cursor marker in your output.
|
||||
Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`).
|
||||
Do not include line numbers in the hunk headers, use `@@ ... @@`.
|
||||
Removed lines begin with `-`.
|
||||
Added lines begin with `+`.
|
||||
Context lines begin with an extra space.
|
||||
Context and removed lines are used to match the target edit location, so make sure to include enough of them
|
||||
to uniquely identify it amongst all excerpts of code provided.
|
||||
"};
|
||||
|
||||
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
@@ -164,49 +79,25 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
|
||||
Remember that the edits in the edit history have already been applied.
|
||||
"#};
|
||||
|
||||
pub fn build_prompt(
|
||||
request: &predict_edits_v3::PredictEditsRequest,
|
||||
) -> Result<(String, SectionLabels)> {
|
||||
let mut section_labels = Default::default();
|
||||
|
||||
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
|
||||
let prompt_data = PromptData {
|
||||
events: request.events.clone(),
|
||||
cursor_point: request.cursor_point,
|
||||
cursor_path: request.excerpt_path.clone(),
|
||||
included_files: request.included_files.clone(),
|
||||
included_files: request.related_files.clone(),
|
||||
};
|
||||
match request.prompt_format {
|
||||
PromptFormat::MinimalQwen => {
|
||||
return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels));
|
||||
return Ok(MinimalQwenPrompt.render(&prompt_data));
|
||||
}
|
||||
PromptFormat::SeedCoder1120 => {
|
||||
return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels));
|
||||
return Ok(SeedCoder1120Prompt.render(&prompt_data));
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let mut insertions = match request.prompt_format {
|
||||
PromptFormat::MarkedExcerpt => vec![
|
||||
(
|
||||
Point {
|
||||
line: request.excerpt_line_range.start,
|
||||
column: 0,
|
||||
},
|
||||
EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
|
||||
),
|
||||
(request.cursor_point, CURSOR_MARKER),
|
||||
(
|
||||
Point {
|
||||
line: request.excerpt_line_range.end,
|
||||
column: 0,
|
||||
},
|
||||
EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
|
||||
),
|
||||
],
|
||||
PromptFormat::LabeledSections
|
||||
| PromptFormat::NumLinesUniDiff
|
||||
| PromptFormat::Minimal
|
||||
| PromptFormat::OldTextNewText => {
|
||||
let insertions = match request.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
|
||||
vec![(request.cursor_point, CURSOR_MARKER)]
|
||||
}
|
||||
PromptFormat::OnlySnippets => vec![],
|
||||
@@ -215,9 +106,6 @@ pub fn build_prompt(
|
||||
};
|
||||
|
||||
let mut prompt = match request.prompt_format {
|
||||
PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OnlySnippets => String::new(),
|
||||
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
|
||||
@@ -247,7 +135,7 @@ pub fn build_prompt(
|
||||
You can only edit exactly this part of the file.
|
||||
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
|
||||
"},
|
||||
PromptFormat::NumLinesUniDiff | PromptFormat::OldTextNewText => indoc! {"
|
||||
PromptFormat::OldTextNewText => indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
Here is some excerpts of code that you should take into account to predict the next edit.
|
||||
@@ -263,64 +151,51 @@ pub fn build_prompt(
|
||||
|
||||
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
|
||||
"},
|
||||
_ => indoc! {"
|
||||
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history have been applied.
|
||||
"},
|
||||
"}
|
||||
}
|
||||
};
|
||||
|
||||
prompt.push_str(excerpts_preamble);
|
||||
prompt.push('\n');
|
||||
|
||||
if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
|
||||
let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?;
|
||||
section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?;
|
||||
} else {
|
||||
if request.prompt_format == PromptFormat::LabeledSections {
|
||||
anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm");
|
||||
}
|
||||
|
||||
let include_line_numbers = matches!(
|
||||
request.prompt_format,
|
||||
PromptFormat::NumLinesUniDiff | PromptFormat::Minimal
|
||||
);
|
||||
for related_file in &request.included_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
}
|
||||
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
|
||||
for related_file in &request.related_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match request.prompt_format {
|
||||
PromptFormat::NumLinesUniDiff => {
|
||||
prompt.push_str(UNIFIED_DIFF_REMINDER);
|
||||
}
|
||||
PromptFormat::OldTextNewText => {
|
||||
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
|
||||
}
|
||||
@@ -330,7 +205,7 @@ pub fn build_prompt(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok((prompt, section_labels))
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
|
||||
@@ -444,476 +319,11 @@ pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>])
|
||||
writeln!(output, "`````\n").unwrap();
|
||||
}
|
||||
|
||||
pub struct SyntaxBasedPrompt<'a> {
|
||||
request: &'a predict_edits_v3::PredictEditsRequest,
|
||||
/// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
|
||||
/// `to_prompt_string`.
|
||||
snippets: Vec<PlannedSnippet<'a>>,
|
||||
budget_used: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PlannedSnippet<'a> {
|
||||
path: Arc<Path>,
|
||||
range: Range<Line>,
|
||||
text: &'a str,
|
||||
// TODO: Indicate this in the output
|
||||
#[allow(dead_code)]
|
||||
text_is_truncated: bool,
|
||||
}
|
||||
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
|
||||
pub enum DeclarationStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize)]
|
||||
pub struct SectionLabels {
|
||||
pub excerpt_index: usize,
|
||||
pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
|
||||
}
|
||||
|
||||
impl<'a> SyntaxBasedPrompt<'a> {
|
||||
/// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
|
||||
///
|
||||
/// Initializes a priority queue by populating it with each snippet, finding the
|
||||
/// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
|
||||
/// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
|
||||
/// the cost of upgrade.
|
||||
///
|
||||
/// TODO: Implement an early halting condition. One option might be to have another priority
|
||||
/// queue where the score is the size, and update it accordingly. Another option might be to
|
||||
/// have some simpler heuristic like bailing after N failed insertions, or based on how much
|
||||
/// budget is left.
|
||||
///
|
||||
/// TODO: Has the current known sources of imprecision:
|
||||
///
|
||||
/// * Does not consider snippet overlap when ranking. For example, it might add a field to the
|
||||
/// plan even though the containing struct is already included.
|
||||
///
|
||||
/// * Does not consider cost of signatures when ranking snippets - this is tricky since
|
||||
/// signatures may be shared by multiple snippets.
|
||||
///
|
||||
/// * Does not include file paths / other text when considering max_bytes.
|
||||
pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
|
||||
let mut this = Self {
|
||||
request,
|
||||
snippets: Vec::new(),
|
||||
budget_used: request.excerpt.len(),
|
||||
};
|
||||
let mut included_parents = FxHashSet::default();
|
||||
let additional_parents = this.additional_parent_signatures(
|
||||
&request.excerpt_path,
|
||||
request.excerpt_parent,
|
||||
&included_parents,
|
||||
)?;
|
||||
this.add_parents(&mut included_parents, additional_parents);
|
||||
|
||||
let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
|
||||
|
||||
if this.budget_used > max_bytes {
|
||||
return Err(anyhow!(
|
||||
"Excerpt + signatures size of {} already exceeds budget of {}",
|
||||
this.budget_used,
|
||||
max_bytes
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct QueueEntry {
|
||||
score_density: OrderedFloat<f32>,
|
||||
declaration_index: usize,
|
||||
style: DeclarationStyle,
|
||||
}
|
||||
|
||||
// Initialize priority queue with the best score for each snippet.
|
||||
let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
|
||||
for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
|
||||
let (style, score_density) = DeclarationStyle::iter()
|
||||
.map(|style| {
|
||||
(
|
||||
style,
|
||||
OrderedFloat(declaration_score_density(&declaration, style)),
|
||||
)
|
||||
})
|
||||
.max_by_key(|(_, score_density)| *score_density)
|
||||
.unwrap();
|
||||
queue.push(QueueEntry {
|
||||
score_density,
|
||||
declaration_index,
|
||||
style,
|
||||
});
|
||||
}
|
||||
|
||||
// Knapsack selection loop
|
||||
while let Some(queue_entry) = queue.pop() {
|
||||
let Some(declaration) = request
|
||||
.referenced_declarations
|
||||
.get(queue_entry.declaration_index)
|
||||
else {
|
||||
return Err(anyhow!(
|
||||
"Invalid declaration index {}",
|
||||
queue_entry.declaration_index
|
||||
));
|
||||
};
|
||||
|
||||
let mut additional_bytes = declaration_size(declaration, queue_entry.style);
|
||||
if this.budget_used + additional_bytes > max_bytes {
|
||||
continue;
|
||||
}
|
||||
|
||||
let additional_parents = this.additional_parent_signatures(
|
||||
&declaration.path,
|
||||
declaration.parent_index,
|
||||
&mut included_parents,
|
||||
)?;
|
||||
additional_bytes += additional_parents
|
||||
.iter()
|
||||
.map(|(_, snippet)| snippet.text.len())
|
||||
.sum::<usize>();
|
||||
if this.budget_used + additional_bytes > max_bytes {
|
||||
continue;
|
||||
}
|
||||
|
||||
this.budget_used += additional_bytes;
|
||||
this.add_parents(&mut included_parents, additional_parents);
|
||||
let planned_snippet = match queue_entry.style {
|
||||
DeclarationStyle::Signature => {
|
||||
let Some(text) = declaration.text.get(declaration.signature_range.clone())
|
||||
else {
|
||||
return Err(anyhow!(
|
||||
"Invalid declaration signature_range {:?} with text.len() = {}",
|
||||
declaration.signature_range,
|
||||
declaration.text.len()
|
||||
));
|
||||
};
|
||||
let signature_start_line = declaration.range.start
|
||||
+ Line(
|
||||
declaration.text[..declaration.signature_range.start]
|
||||
.lines()
|
||||
.count() as u32,
|
||||
);
|
||||
let signature_end_line = signature_start_line
|
||||
+ Line(
|
||||
declaration.text
|
||||
[declaration.signature_range.start..declaration.signature_range.end]
|
||||
.lines()
|
||||
.count() as u32,
|
||||
);
|
||||
let range = signature_start_line..signature_end_line;
|
||||
|
||||
PlannedSnippet {
|
||||
path: declaration.path.clone(),
|
||||
range,
|
||||
text,
|
||||
text_is_truncated: declaration.text_is_truncated,
|
||||
}
|
||||
}
|
||||
DeclarationStyle::Declaration => PlannedSnippet {
|
||||
path: declaration.path.clone(),
|
||||
range: declaration.range.clone(),
|
||||
text: &declaration.text,
|
||||
text_is_truncated: declaration.text_is_truncated,
|
||||
},
|
||||
};
|
||||
this.snippets.push(planned_snippet);
|
||||
|
||||
// When a Signature is consumed, insert an entry for Definition style.
|
||||
if queue_entry.style == DeclarationStyle::Signature {
|
||||
let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
|
||||
let declaration_size =
|
||||
declaration_size(&declaration, DeclarationStyle::Declaration);
|
||||
let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
|
||||
let declaration_score =
|
||||
declaration_score(&declaration, DeclarationStyle::Declaration);
|
||||
|
||||
let score_diff = declaration_score - signature_score;
|
||||
let size_diff = declaration_size.saturating_sub(signature_size);
|
||||
if score_diff > 0.0001 && size_diff > 0 {
|
||||
queue.push(QueueEntry {
|
||||
declaration_index: queue_entry.declaration_index,
|
||||
score_density: OrderedFloat(score_diff / (size_diff as f32)),
|
||||
style: DeclarationStyle::Declaration,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(this)
|
||||
}
|
||||
|
||||
fn add_parents(
|
||||
&mut self,
|
||||
included_parents: &mut FxHashSet<usize>,
|
||||
snippets: Vec<(usize, PlannedSnippet<'a>)>,
|
||||
) {
|
||||
for (parent_index, snippet) in snippets {
|
||||
included_parents.insert(parent_index);
|
||||
self.budget_used += snippet.text.len();
|
||||
self.snippets.push(snippet);
|
||||
}
|
||||
}
|
||||
|
||||
fn additional_parent_signatures(
|
||||
&self,
|
||||
path: &Arc<Path>,
|
||||
parent_index: Option<usize>,
|
||||
included_parents: &FxHashSet<usize>,
|
||||
) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
|
||||
let mut results = Vec::new();
|
||||
self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn additional_parent_signatures_impl(
|
||||
&self,
|
||||
path: &Arc<Path>,
|
||||
parent_index: Option<usize>,
|
||||
included_parents: &FxHashSet<usize>,
|
||||
results: &mut Vec<(usize, PlannedSnippet<'a>)>,
|
||||
) -> Result<()> {
|
||||
let Some(parent_index) = parent_index else {
|
||||
return Ok(());
|
||||
};
|
||||
if included_parents.contains(&parent_index) {
|
||||
return Ok(());
|
||||
}
|
||||
let Some(parent_signature) = self.request.signatures.get(parent_index) else {
|
||||
return Err(anyhow!("Invalid parent index {}", parent_index));
|
||||
};
|
||||
results.push((
|
||||
parent_index,
|
||||
PlannedSnippet {
|
||||
path: path.clone(),
|
||||
range: parent_signature.range.clone(),
|
||||
text: &parent_signature.text,
|
||||
text_is_truncated: parent_signature.text_is_truncated,
|
||||
},
|
||||
));
|
||||
self.additional_parent_signatures_impl(
|
||||
path,
|
||||
parent_signature.parent_index,
|
||||
included_parents,
|
||||
results,
|
||||
)
|
||||
}
|
||||
|
||||
/// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
|
||||
/// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
|
||||
/// chunks.
|
||||
pub fn write(
|
||||
&'a self,
|
||||
excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
|
||||
prompt: &mut String,
|
||||
) -> Result<SectionLabels> {
|
||||
let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
|
||||
FxHashMap::default();
|
||||
for snippet in &self.snippets {
|
||||
file_to_snippets
|
||||
.entry(&snippet.path)
|
||||
.or_default()
|
||||
.push(snippet);
|
||||
}
|
||||
|
||||
// Reorder so that file with cursor comes last
|
||||
let mut file_snippets = Vec::new();
|
||||
let mut excerpt_file_snippets = Vec::new();
|
||||
for (file_path, snippets) in file_to_snippets {
|
||||
if file_path == self.request.excerpt_path.as_ref() {
|
||||
excerpt_file_snippets = snippets;
|
||||
} else {
|
||||
file_snippets.push((file_path, snippets, false));
|
||||
}
|
||||
}
|
||||
let excerpt_snippet = PlannedSnippet {
|
||||
path: self.request.excerpt_path.clone(),
|
||||
range: self.request.excerpt_line_range.clone(),
|
||||
text: &self.request.excerpt,
|
||||
text_is_truncated: false,
|
||||
};
|
||||
excerpt_file_snippets.push(&excerpt_snippet);
|
||||
file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
|
||||
|
||||
let section_labels =
|
||||
self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?;
|
||||
|
||||
Ok(section_labels)
|
||||
}
|
||||
|
||||
fn push_file_snippets(
|
||||
&self,
|
||||
output: &mut String,
|
||||
excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
|
||||
file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
|
||||
) -> Result<SectionLabels> {
|
||||
let mut section_ranges = Vec::new();
|
||||
let mut excerpt_index = None;
|
||||
|
||||
for (file_path, mut snippets, is_excerpt_file) in file_snippets {
|
||||
snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
|
||||
|
||||
// TODO: What if the snippets get expanded too large to be editable?
|
||||
let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
|
||||
let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
|
||||
for snippet in snippets {
|
||||
if let Some((_, current_snippet_range)) = current_snippet.as_mut()
|
||||
&& snippet.range.start <= current_snippet_range.end
|
||||
{
|
||||
current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
|
||||
continue;
|
||||
}
|
||||
if let Some(current_snippet) = current_snippet.take() {
|
||||
disjoint_snippets.push(current_snippet);
|
||||
}
|
||||
current_snippet = Some((snippet, snippet.range.clone()));
|
||||
}
|
||||
if let Some(current_snippet) = current_snippet.take() {
|
||||
disjoint_snippets.push(current_snippet);
|
||||
}
|
||||
|
||||
writeln!(output, "`````path={}", file_path.display()).ok();
|
||||
let mut skipped_last_snippet = false;
|
||||
for (snippet, range) in disjoint_snippets {
|
||||
let section_index = section_ranges.len();
|
||||
|
||||
match self.request.prompt_format {
|
||||
PromptFormat::MarkedExcerpt
|
||||
| PromptFormat::OnlySnippets
|
||||
| PromptFormat::OldTextNewText
|
||||
| PromptFormat::Minimal
|
||||
| PromptFormat::NumLinesUniDiff => {
|
||||
if range.start.0 > 0 && !skipped_last_snippet {
|
||||
output.push_str("…\n");
|
||||
}
|
||||
}
|
||||
PromptFormat::LabeledSections => {
|
||||
if is_excerpt_file
|
||||
&& range.start <= self.request.excerpt_line_range.start
|
||||
&& range.end >= self.request.excerpt_line_range.end
|
||||
{
|
||||
writeln!(output, "<|current_section|>").ok();
|
||||
} else {
|
||||
writeln!(output, "<|section_{}|>", section_index).ok();
|
||||
}
|
||||
}
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
}
|
||||
|
||||
let push_full_snippet = |output: &mut String| {
|
||||
if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
|
||||
for (i, line) in snippet.text.lines().enumerate() {
|
||||
writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
|
||||
}
|
||||
} else {
|
||||
output.push_str(&snippet.text);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
if is_excerpt_file {
|
||||
if self.request.prompt_format == PromptFormat::OnlySnippets {
|
||||
if range.start >= self.request.excerpt_line_range.start
|
||||
&& range.end <= self.request.excerpt_line_range.end
|
||||
{
|
||||
skipped_last_snippet = true;
|
||||
} else {
|
||||
skipped_last_snippet = false;
|
||||
output.push_str(snippet.text);
|
||||
}
|
||||
} else if !excerpt_file_insertions.is_empty() {
|
||||
let lines = snippet.text.lines().collect::<Vec<_>>();
|
||||
let push_line = |output: &mut String, line_ix: usize| {
|
||||
if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
|
||||
write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
|
||||
}
|
||||
anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
|
||||
};
|
||||
let mut last_line_ix = 0;
|
||||
let mut insertion_ix = 0;
|
||||
while insertion_ix < excerpt_file_insertions.len() {
|
||||
let (point, insertion) = &excerpt_file_insertions[insertion_ix];
|
||||
let found = point.line >= range.start && point.line <= range.end;
|
||||
if found {
|
||||
excerpt_index = Some(section_index);
|
||||
let insertion_line_ix = (point.line.0 - range.start.0) as usize;
|
||||
for line_ix in last_line_ix..insertion_line_ix {
|
||||
push_line(output, line_ix)?;
|
||||
}
|
||||
if let Some(next_line) = lines.get(insertion_line_ix) {
|
||||
if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
|
||||
write!(
|
||||
output,
|
||||
"{}|",
|
||||
insertion_line_ix as u32 + range.start.0 + 1
|
||||
)?
|
||||
}
|
||||
output.push_str(&next_line[..point.column as usize]);
|
||||
output.push_str(insertion);
|
||||
writeln!(output, "{}", &next_line[point.column as usize..])?;
|
||||
} else {
|
||||
writeln!(output, "{}", insertion)?;
|
||||
}
|
||||
last_line_ix = insertion_line_ix + 1;
|
||||
excerpt_file_insertions.remove(insertion_ix);
|
||||
continue;
|
||||
}
|
||||
insertion_ix += 1;
|
||||
}
|
||||
skipped_last_snippet = false;
|
||||
for line_ix in last_line_ix..lines.len() {
|
||||
push_line(output, line_ix)?;
|
||||
}
|
||||
} else {
|
||||
skipped_last_snippet = false;
|
||||
push_full_snippet(output)?;
|
||||
}
|
||||
} else {
|
||||
skipped_last_snippet = false;
|
||||
push_full_snippet(output)?;
|
||||
}
|
||||
|
||||
section_ranges.push((snippet.path.clone(), range));
|
||||
}
|
||||
|
||||
output.push_str("`````\n\n");
|
||||
}
|
||||
|
||||
Ok(SectionLabels {
|
||||
// TODO: Clean this up
|
||||
excerpt_index: match self.request.prompt_format {
|
||||
PromptFormat::OnlySnippets => 0,
|
||||
_ => excerpt_index.context("bug: no snippet found for excerpt")?,
|
||||
},
|
||||
section_ranges,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
|
||||
declaration_score(declaration, style) / declaration_size(declaration, style) as f32
|
||||
}
|
||||
|
||||
fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
|
||||
match style {
|
||||
DeclarationStyle::Signature => declaration.signature_score,
|
||||
DeclarationStyle::Declaration => declaration.declaration_score,
|
||||
}
|
||||
}
|
||||
|
||||
fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
|
||||
match style {
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.text.len(),
|
||||
}
|
||||
}
|
||||
|
||||
struct PromptData {
|
||||
events: Vec<Arc<Event>>,
|
||||
cursor_point: Point,
|
||||
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
|
||||
included_files: Vec<IncludedFile>,
|
||||
included_files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -1051,7 +461,7 @@ impl SeedCoder1120Prompt {
|
||||
context
|
||||
}
|
||||
|
||||
fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String {
|
||||
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
|
||||
let mut buf = String::new();
|
||||
const FIM_SUFFIX: &str = "<[fim-suffix]>";
|
||||
const FIM_PREFIX: &str = "<[fim-prefix]>";
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{self, Excerpt};
|
||||
use indoc::indoc;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::{push_events, write_codeblock};
|
||||
|
||||
pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
|
||||
let mut prompt = SEARCH_INSTRUCTIONS.to_string();
|
||||
|
||||
if !request.events.is_empty() {
|
||||
writeln!(&mut prompt, "\n## User Edits\n\n")?;
|
||||
push_events(&mut prompt, &request.events);
|
||||
}
|
||||
|
||||
writeln!(&mut prompt, "## Cursor context\n")?;
|
||||
write_codeblock(
|
||||
&request.excerpt_path,
|
||||
&[Excerpt {
|
||||
start_line: request.excerpt_line_range.start,
|
||||
text: request.excerpt.into(),
|
||||
}],
|
||||
&[],
|
||||
request.cursor_file_max_row,
|
||||
true,
|
||||
&mut prompt,
|
||||
);
|
||||
|
||||
writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
/// Search for relevant code
|
||||
///
|
||||
/// For the best results, run multiple queries at once with a single invocation of this tool.
|
||||
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
|
||||
pub struct SearchToolInput {
|
||||
/// An array of queries to run for gathering context relevant to the next prediction
|
||||
#[schemars(length(max = 3))]
|
||||
#[serde(deserialize_with = "deserialize_queries")]
|
||||
pub queries: Box<[SearchToolQuery]>,
|
||||
}
|
||||
|
||||
fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum QueryCollection {
|
||||
Array(Box<[SearchToolQuery]>),
|
||||
DoubleArray(Box<[Box<[SearchToolQuery]>]>),
|
||||
Single(SearchToolQuery),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum MaybeDoubleEncoded {
|
||||
SingleEncoded(QueryCollection),
|
||||
DoubleEncoded(String),
|
||||
}
|
||||
|
||||
let result = MaybeDoubleEncoded::deserialize(deserializer)?;
|
||||
|
||||
let normalized = match result {
|
||||
MaybeDoubleEncoded::SingleEncoded(value) => value,
|
||||
MaybeDoubleEncoded::DoubleEncoded(value) => {
|
||||
serde_json::from_str(&value).map_err(D::Error::custom)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(match normalized {
|
||||
QueryCollection::Array(items) => items,
|
||||
QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
|
||||
QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Search for relevant code by path, syntax hierarchy, and content.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
|
||||
pub struct SearchToolQuery {
|
||||
/// 1. A glob pattern to match file paths in the codebase to search in.
|
||||
pub glob: String,
|
||||
/// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
|
||||
///
|
||||
/// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
|
||||
///
|
||||
/// Example: Searching for a `User` class
|
||||
/// ["class\s+User"]
|
||||
///
|
||||
/// Example: Searching for a `get_full_name` method under a `User` class
|
||||
/// ["class\s+User", "def\sget_full_name"]
|
||||
///
|
||||
/// Skip this field to match on content alone.
|
||||
#[schemars(length(max = 3))]
|
||||
#[serde(default)]
|
||||
pub syntax_node: Vec<String>,
|
||||
/// 3. An optional regular expression to match the final content that should appear in the results.
|
||||
///
|
||||
/// - Content will be matched within all lines of the matched syntax nodes.
|
||||
/// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
|
||||
/// - If no syntax node regexes are provided, the content will be matched within the entire file.
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
pub const TOOL_NAME: &str = "search";
|
||||
|
||||
const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are part of an edit prediction system in a code editor.
|
||||
Your role is to search for code that will serve as context for predicting the next edit.
|
||||
|
||||
- Analyze the user's recent edits and current cursor context
|
||||
- Use the `search` tool to find code that is relevant for predicting the next edit
|
||||
- Focus on finding:
|
||||
- Code patterns that might need similar changes based on the recent edits
|
||||
- Functions, variables, types, and constants referenced in the current cursor context
|
||||
- Related implementations, usages, or dependencies that may require consistent updates
|
||||
- How items defined in the cursor excerpt are used or altered
|
||||
- You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
|
||||
- Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
|
||||
- Avoid using wildcard globs if you already know the file path of the content you're looking for
|
||||
"#};
|
||||
|
||||
const TOOL_USE_REMINDER: &str = indoc! {"
|
||||
--
|
||||
Analyze the user's intent in one to two sentences, then call the `search` tool.
|
||||
"};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_queries() {
|
||||
let single_query_json = indoc! {r#"{
|
||||
"queries": {
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
}
|
||||
}"#};
|
||||
|
||||
let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
|
||||
assert_eq!(flat_input.queries.len(), 1);
|
||||
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
|
||||
|
||||
let flat_json = indoc! {r#"{
|
||||
"queries": [
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
},
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
}"#};
|
||||
|
||||
let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
|
||||
assert_eq!(flat_input.queries.len(), 2);
|
||||
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
|
||||
assert_eq!(flat_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(flat_input.queries[1].content, None);
|
||||
|
||||
let nested_json = indoc! {r#"{
|
||||
"queries": [
|
||||
[
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
]
|
||||
}"#};
|
||||
|
||||
let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
|
||||
|
||||
assert_eq!(nested_input.queries.len(), 2);
|
||||
|
||||
assert_eq!(nested_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
|
||||
assert_eq!(nested_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(nested_input.queries[1].content, None);
|
||||
|
||||
let double_encoded_queries = serde_json::to_string(&json!({
|
||||
"queries": serde_json::to_string(&json!([
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
},
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
])).unwrap()
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let double_encoded_input: SearchToolInput =
|
||||
serde_json::from_str(&double_encoded_queries).unwrap();
|
||||
|
||||
assert_eq!(double_encoded_input.queries.len(), 2);
|
||||
|
||||
assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(
|
||||
double_encoded_input.queries[0].content,
|
||||
Some("assert".to_string())
|
||||
);
|
||||
assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(double_encoded_input.queries[1].content, None);
|
||||
|
||||
// ### ERROR Switching from var declarations to lexical declarations [RUN 073]
|
||||
// invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ path = "src/codestral.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
|
||||
use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::{App, Context, Entity, Task};
|
||||
use http_client::HttpClient;
|
||||
@@ -43,17 +43,17 @@ impl CurrentCompletion {
|
||||
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
|
||||
/// Returns None if the user's edits conflict with the predicted edits.
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CodestralCompletionProvider {
|
||||
pub struct CodestralEditPredictionDelegate {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
pending_request: Option<Task<Result<()>>>,
|
||||
current_completion: Option<CurrentCompletion>,
|
||||
}
|
||||
|
||||
impl CodestralCompletionProvider {
|
||||
impl CodestralEditPredictionDelegate {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
@@ -165,7 +165,7 @@ impl CodestralCompletionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionProvider for CodestralCompletionProvider {
|
||||
impl EditPredictionDelegate for CodestralEditPredictionDelegate {
|
||||
fn name() -> &'static str {
|
||||
"codestral"
|
||||
}
|
||||
@@ -174,7 +174,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
|
||||
"Codestral"
|
||||
}
|
||||
|
||||
fn show_completions_in_menu() -> bool {
|
||||
fn show_predictions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -239,7 +239,6 @@ impl EditPredictionProvider for CodestralCompletionProvider {
|
||||
cursor_point,
|
||||
&snapshot,
|
||||
&EXCERPT_OPTIONS,
|
||||
None,
|
||||
)
|
||||
.context("Line containing cursor doesn't fit in excerpt max bytes")?;
|
||||
|
||||
|
||||
@@ -121,6 +121,8 @@ CREATE TABLE "project_repositories" (
|
||||
"merge_message" VARCHAR,
|
||||
"branch_summary" VARCHAR,
|
||||
"head_commit_details" VARCHAR,
|
||||
"remote_upstream_url" VARCHAR,
|
||||
"remote_origin_url" VARCHAR,
|
||||
PRIMARY KEY (project_id, id)
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE "project_repositories" ADD COLUMN "remote_upstream_url" VARCHAR;
|
||||
ALTER TABLE "project_repositories" ADD COLUMN "remote_origin_url" VARCHAR;
|
||||
@@ -362,6 +362,8 @@ impl Database {
|
||||
entry_ids: ActiveValue::set("[]".into()),
|
||||
head_commit_details: ActiveValue::set(None),
|
||||
merge_message: ActiveValue::set(None),
|
||||
remote_upstream_url: ActiveValue::set(None),
|
||||
remote_origin_url: ActiveValue::set(None),
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -511,6 +513,8 @@ impl Database {
|
||||
serde_json::to_string(&update.current_merge_conflicts).unwrap(),
|
||||
)),
|
||||
merge_message: ActiveValue::set(update.merge_message.clone()),
|
||||
remote_upstream_url: ActiveValue::set(update.remote_upstream_url.clone()),
|
||||
remote_origin_url: ActiveValue::set(update.remote_origin_url.clone()),
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
@@ -1005,6 +1009,8 @@ impl Database {
|
||||
is_last_update: true,
|
||||
merge_message: db_repository_entry.merge_message,
|
||||
stash_entries: Vec::new(),
|
||||
remote_upstream_url: db_repository_entry.remote_upstream_url.clone(),
|
||||
remote_origin_url: db_repository_entry.remote_origin_url.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -796,6 +796,8 @@ impl Database {
|
||||
is_last_update: true,
|
||||
merge_message: db_repository.merge_message,
|
||||
stash_entries: Vec::new(),
|
||||
remote_upstream_url: db_repository.remote_upstream_url.clone(),
|
||||
remote_origin_url: db_repository.remote_origin_url.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ pub struct Model {
|
||||
pub branch_summary: Option<String>,
|
||||
// A JSON object representing the current Head commit values
|
||||
pub head_commit_details: Option<String>,
|
||||
pub remote_upstream_url: Option<String>,
|
||||
pub remote_origin_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
||||
@@ -469,6 +469,8 @@ impl Server {
|
||||
.add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
|
||||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||
.add_message_handler(update_context)
|
||||
|
||||
@@ -3518,7 +3518,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
|
||||
.into_iter()
|
||||
.map(|(sha, message)| (sha.parse().unwrap(), message.into()))
|
||||
.collect(),
|
||||
remote_url: Some("git@github.com:zed-industries/zed.git".to_string()),
|
||||
};
|
||||
client_a.fs().set_blame_for_repo(
|
||||
Path::new(path!("/my-repo/.git")),
|
||||
@@ -3603,10 +3602,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
|
||||
for (idx, (buffer, entry)) in entries.iter().flatten().enumerate() {
|
||||
let details = blame.details_for_entry(*buffer, entry).unwrap();
|
||||
assert_eq!(details.message, format!("message for idx-{}", idx));
|
||||
assert_eq!(
|
||||
details.permalink.unwrap().to_string(),
|
||||
format!("https://github.com/zed-industries/zed/commit/{}", entry.sha)
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,7 +33,7 @@ fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pub mod copilot_chat;
|
||||
mod copilot_completion_provider;
|
||||
mod copilot_edit_prediction_delegate;
|
||||
pub mod copilot_responses;
|
||||
pub mod request;
|
||||
mod sign_in;
|
||||
@@ -46,7 +46,7 @@ use util::rel_path::RelPath;
|
||||
use util::{ResultExt, fs::remove_matching};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::copilot_completion_provider::CopilotCompletionProvider;
|
||||
pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
|
||||
pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
|
||||
|
||||
actions!(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{Completion, Copilot};
|
||||
use anyhow::Result;
|
||||
use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
|
||||
use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
|
||||
use settings::Settings;
|
||||
@@ -8,7 +8,7 @@ use std::{path::Path, time::Duration};
|
||||
|
||||
pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
|
||||
|
||||
pub struct CopilotCompletionProvider {
|
||||
pub struct CopilotEditPredictionDelegate {
|
||||
cycled: bool,
|
||||
buffer_id: Option<EntityId>,
|
||||
completions: Vec<Completion>,
|
||||
@@ -19,7 +19,7 @@ pub struct CopilotCompletionProvider {
|
||||
copilot: Entity<Copilot>,
|
||||
}
|
||||
|
||||
impl CopilotCompletionProvider {
|
||||
impl CopilotEditPredictionDelegate {
|
||||
pub fn new(copilot: Entity<Copilot>) -> Self {
|
||||
Self {
|
||||
cycled: false,
|
||||
@@ -47,7 +47,7 @@ impl CopilotCompletionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
impl EditPredictionDelegate for CopilotEditPredictionDelegate {
|
||||
fn name() -> &'static str {
|
||||
"copilot"
|
||||
}
|
||||
@@ -56,7 +56,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
"Copilot"
|
||||
}
|
||||
|
||||
fn show_completions_in_menu() -> bool {
|
||||
fn show_predictions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -314,7 +314,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -546,7 +546,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -670,7 +670,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -753,7 +753,7 @@ mod tests {
|
||||
window.focus(&editor.focus_handle(cx));
|
||||
})
|
||||
.unwrap();
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
editor
|
||||
.update(cx, |editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
@@ -848,7 +848,7 @@ mod tests {
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
});
|
||||
@@ -1000,7 +1000,7 @@ mod tests {
|
||||
window.focus(&editor.focus_handle(cx))
|
||||
})
|
||||
.unwrap();
|
||||
let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
|
||||
let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
|
||||
editor
|
||||
.update(cx, |editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
|
||||
@@ -387,7 +387,7 @@ pub fn init(cx: &mut App) {
|
||||
window.on_action(
|
||||
TypeId::of::<editor::actions::EvaluateSelectedText>(),
|
||||
move |_, _, window, cx| {
|
||||
maybe!({
|
||||
let status = maybe!({
|
||||
let text = editor
|
||||
.update(cx, |editor, cx| {
|
||||
let range = editor
|
||||
@@ -411,7 +411,13 @@ pub fn init(cx: &mut App) {
|
||||
|
||||
state.session().update(cx, |session, cx| {
|
||||
session
|
||||
.evaluate(text, None, stack_id, None, cx)
|
||||
.evaluate(
|
||||
text,
|
||||
Some(dap::EvaluateArgumentsContext::Repl),
|
||||
stack_id,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
});
|
||||
});
|
||||
@@ -419,6 +425,9 @@ pub fn init(cx: &mut App) {
|
||||
|
||||
Some(())
|
||||
});
|
||||
if status.is_some() {
|
||||
cx.stop_propagation();
|
||||
}
|
||||
},
|
||||
);
|
||||
})
|
||||
|
||||
@@ -1023,7 +1023,7 @@ impl DebugDelegate {
|
||||
Some(TaskSourceKind::Lsp { language_name, .. }) => {
|
||||
Some(format!("LSP: {language_name}"))
|
||||
}
|
||||
Some(TaskSourceKind::Language { name }) => Some(format!("Lang: {name}")),
|
||||
Some(TaskSourceKind::Language { name }) => Some(format!("Language: {name}")),
|
||||
_ => context.clone().and_then(|ctx| {
|
||||
ctx.task_context
|
||||
.task_variables
|
||||
|
||||
@@ -11,7 +11,69 @@ workspace = true
|
||||
[lib]
|
||||
path = "src/edit_prediction.rs"
|
||||
|
||||
[features]
|
||||
eval-support = []
|
||||
|
||||
[dependencies]
|
||||
ai_onboarding.workspace = true
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
db.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
menu.workspace = true
|
||||
open_ai.workspace = true
|
||||
postage.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
release_channel.workspace = true
|
||||
semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strsim.workspace = true
|
||||
strum.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
thiserror.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
cloud_api_types.workspace = true
|
||||
cloud_llm_client = { workspace = true, features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
lsp.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1806
crates/edit_prediction/src/edit_prediction_tests.rs
Normal file
1806
crates/edit_prediction/src/edit_prediction_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -99,7 +99,7 @@ pub struct EditPrediction {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct EditPredictionInputs {
|
||||
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
|
||||
pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
|
||||
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
|
||||
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
|
||||
pub cursor_path: Arc<Path>,
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
@@ -49,6 +50,7 @@ impl SweepAi {
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
recent_paths: &VecDeque<ProjectPath>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
@@ -120,6 +122,19 @@ impl SweepAi {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let retrieval_chunks = related_files
|
||||
.iter()
|
||||
.flat_map(|related_file| {
|
||||
related_file.excerpts.iter().map(|excerpt| FileChunk {
|
||||
file_path: related_file.path.path.as_unix_str().to_string(),
|
||||
start_line: excerpt.point_range.start.row as usize,
|
||||
end_line: excerpt.point_range.end.row as usize,
|
||||
content: excerpt.text.to_string(),
|
||||
timestamp: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
|
||||
let mut diagnostic_content = String::new();
|
||||
let mut diagnostic_count = 0;
|
||||
@@ -168,7 +183,7 @@ impl SweepAi {
|
||||
multiple_suggestions: false,
|
||||
branch: None,
|
||||
file_chunks,
|
||||
retrieval_chunks: vec![],
|
||||
retrieval_chunks,
|
||||
recent_user_actions: vec![],
|
||||
use_bytes: true,
|
||||
// TODO
|
||||
@@ -182,7 +197,7 @@ impl SweepAi {
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
@@ -320,7 +335,7 @@ struct AutocompleteRequest {
|
||||
pub cursor_position: usize,
|
||||
pub original_file_contents: String,
|
||||
pub file_chunks: Vec<FileChunk>,
|
||||
pub retrieval_chunks: Vec<RetrievalChunk>,
|
||||
pub retrieval_chunks: Vec<FileChunk>,
|
||||
pub recent_user_actions: Vec<UserAction>,
|
||||
pub multiple_suggestions: bool,
|
||||
pub privacy_mode_enabled: bool,
|
||||
@@ -337,15 +352,6 @@ struct FileChunk {
|
||||
pub timestamp: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct RetrievalChunk {
|
||||
pub file_path: String,
|
||||
pub start_line: usize,
|
||||
pub end_line: usize,
|
||||
pub content: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct UserAction {
|
||||
pub action_type: ActionType,
|
||||
@@ -1,55 +1,56 @@
|
||||
use std::{cmp, sync::Arc, time::Duration};
|
||||
use std::{cmp, sync::Arc};
|
||||
|
||||
use client::{Client, UserStore};
|
||||
use cloud_llm_client::EditPredictionRejectReason;
|
||||
use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
|
||||
use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate};
|
||||
use gpui::{App, Entity, prelude::*};
|
||||
use language::ToPoint as _;
|
||||
use language::{Buffer, ToPoint as _};
|
||||
use project::Project;
|
||||
|
||||
use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
|
||||
use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
|
||||
|
||||
pub struct ZetaEditPredictionProvider {
|
||||
zeta: Entity<Zeta>,
|
||||
pub struct ZedEditPredictionDelegate {
|
||||
store: Entity<EditPredictionStore>,
|
||||
project: Entity<Project>,
|
||||
singleton_buffer: Option<Entity<Buffer>>,
|
||||
}
|
||||
|
||||
impl ZetaEditPredictionProvider {
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
|
||||
impl ZedEditPredictionDelegate {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
singleton_buffer: Option<Entity<Buffer>>,
|
||||
client: &Arc<Client>,
|
||||
user_store: &Entity<UserStore>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let zeta = Zeta::global(client, user_store, cx);
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.register_project(&project, cx);
|
||||
let store = EditPredictionStore::global(client, user_store, cx);
|
||||
store.update(cx, |store, cx| {
|
||||
store.register_project(&project, cx);
|
||||
});
|
||||
|
||||
cx.observe(&zeta, |_this, _zeta, cx| {
|
||||
cx.observe(&store, |_this, _ep_store, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
project: project,
|
||||
zeta,
|
||||
store: store,
|
||||
singleton_buffer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
fn name() -> &'static str {
|
||||
"zed-predict2"
|
||||
"zed-predict"
|
||||
}
|
||||
|
||||
fn display_name() -> &'static str {
|
||||
"Zed's Edit Predictions 2"
|
||||
"Zed's Edit Predictions"
|
||||
}
|
||||
|
||||
fn show_completions_in_menu() -> bool {
|
||||
fn show_predictions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
@@ -57,17 +58,38 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
true
|
||||
}
|
||||
|
||||
fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
|
||||
// TODO [zeta2]
|
||||
DataCollectionState::Unsupported
|
||||
fn data_collection_state(&self, cx: &App) -> DataCollectionState {
|
||||
if let Some(buffer) = &self.singleton_buffer
|
||||
&& let Some(file) = buffer.read(cx).file()
|
||||
{
|
||||
let is_project_open_source =
|
||||
self.store
|
||||
.read(cx)
|
||||
.is_file_open_source(&self.project, file, cx);
|
||||
if self.store.read(cx).data_collection_choice.is_enabled() {
|
||||
DataCollectionState::Enabled {
|
||||
is_project_open_source,
|
||||
}
|
||||
} else {
|
||||
DataCollectionState::Disabled {
|
||||
is_project_open_source,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return DataCollectionState::Disabled {
|
||||
is_project_open_source: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn toggle_data_collection(&mut self, _cx: &mut App) {
|
||||
// TODO [zeta2]
|
||||
fn toggle_data_collection(&mut self, cx: &mut App) {
|
||||
self.store.update(cx, |store, cx| {
|
||||
store.toggle_data_collection_choice(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
|
||||
self.zeta.read(cx).usage(cx)
|
||||
self.store.read(cx).usage(cx)
|
||||
}
|
||||
|
||||
fn is_enabled(
|
||||
@@ -76,16 +98,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
_cursor_position: language::Anchor,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let zeta = self.zeta.read(cx);
|
||||
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
|
||||
zeta.has_sweep_api_token()
|
||||
let store = self.store.read(cx);
|
||||
if store.edit_prediction_model == EditPredictionModel::Sweep {
|
||||
store.has_sweep_api_token()
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn is_refreshing(&self, cx: &App) -> bool {
|
||||
self.zeta.read(cx).is_refreshing(&self.project)
|
||||
self.store.read(cx).is_refreshing(&self.project)
|
||||
}
|
||||
|
||||
fn refresh(
|
||||
@@ -95,24 +117,24 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
_debounce: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let zeta = self.zeta.read(cx);
|
||||
let store = self.store.read(cx);
|
||||
|
||||
if zeta.user_store.read_with(cx, |user_store, _cx| {
|
||||
if store.user_store.read_with(cx, |user_store, _cx| {
|
||||
user_store.account_too_young() || user_store.has_overdue_invoices()
|
||||
}) {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
|
||||
zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
|
||||
self.store.update(cx, |store, cx| {
|
||||
store.refresh_context(&self.project, &buffer, cursor_position, cx);
|
||||
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
|
||||
});
|
||||
}
|
||||
|
||||
@@ -126,20 +148,20 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
}
|
||||
|
||||
fn accept(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.accept_current_prediction(&self.project, cx);
|
||||
self.store.update(cx, |store, cx| {
|
||||
store.accept_current_prediction(&self.project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn discard(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
|
||||
self.store.update(cx, |store, _cx| {
|
||||
store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
|
||||
});
|
||||
}
|
||||
|
||||
fn did_show(&mut self, cx: &mut Context<Self>) {
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.did_show_current_prediction(&self.project, cx);
|
||||
self.store.update(cx, |store, cx| {
|
||||
store.did_show_current_prediction(&self.project, cx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -148,16 +170,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
buffer: &Entity<language::Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction::EditPrediction> {
|
||||
) -> Option<edit_prediction_types::EditPrediction> {
|
||||
let prediction =
|
||||
self.zeta
|
||||
self.store
|
||||
.read(cx)
|
||||
.current_prediction_for_buffer(buffer, &self.project, cx)?;
|
||||
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction::EditPrediction::Jump {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
@@ -169,8 +191,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
let snapshot = buffer.snapshot();
|
||||
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
self.zeta.update(cx, |zeta, _cx| {
|
||||
zeta.reject_current_prediction(
|
||||
self.store.update(cx, |store, _cx| {
|
||||
store.reject_current_prediction(
|
||||
EditPredictionRejectReason::InterpolatedEmpty,
|
||||
&self.project,
|
||||
);
|
||||
@@ -208,7 +230,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction::EditPrediction::Local {
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
@@ -3,7 +3,7 @@ mod input_excerpt;
|
||||
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, ZedUpdateRequiredError, Zeta,
|
||||
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
|
||||
prediction::{EditPredictionInputs, EditPredictionResult},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
@@ -30,23 +30,23 @@ pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
|
||||
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
|
||||
|
||||
pub(crate) fn request_prediction_with_zeta1(
|
||||
zeta: &mut Zeta,
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
cx: &mut Context<Zeta>,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let client = zeta.client.clone();
|
||||
let llm_token = zeta.llm_token.clone();
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
|
||||
let can_collect_file = zeta.can_collect_file(project, file, cx);
|
||||
let can_collect_file = store.can_collect_file(project, file, cx);
|
||||
let git_info = if can_collect_file {
|
||||
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
} else {
|
||||
@@ -102,7 +102,7 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
|
||||
let http_client = client.http_client();
|
||||
|
||||
let response = Zeta::send_api_request::<PredictEditsResponse>(
|
||||
let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
|
||||
|request| {
|
||||
let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
|
||||
predict_edits_url
|
||||
@@ -124,7 +124,7 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events: included_events.into(),
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
@@ -155,8 +155,8 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
Err(err) => {
|
||||
if err.is::<ZedUpdateRequiredError>() {
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |zeta, _cx| {
|
||||
zeta.update_required = true;
|
||||
this.update(cx, |ep_store, _cx| {
|
||||
ep_store.update_required = true;
|
||||
})
|
||||
.ok();
|
||||
|
||||
358
crates/edit_prediction/src/zeta2.rs
Normal file
358
crates/edit_prediction/src/zeta2.rs
Normal file
@@ -0,0 +1,358 @@
|
||||
#[cfg(feature = "eval-support")]
|
||||
use crate::EvalCacheEntryKind;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
use crate::{
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
|
||||
EditPredictionRequestedDebugEvent, EditPredictionStore,
|
||||
};
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, Line};
|
||||
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{Entity, Task, prelude::*};
|
||||
use language::{Anchor, BufferSnapshot};
|
||||
use language::{Buffer, Point, ToOffset as _, ToPoint};
|
||||
use project::{Project, ProjectItem as _};
|
||||
use release_channel::AppVersion;
|
||||
use std::{
|
||||
env,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
pub fn request_prediction_with_zeta2(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
active_snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
mut included_files: Vec<RelatedFile>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let options = store.options.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let Some((excerpt_path, active_project_path)) = active_snapshot
|
||||
.file()
|
||||
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
|
||||
.zip(active_buffer.read(cx).project_path(cx))
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("No file path for excerpt")));
|
||||
};
|
||||
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
let debug_tx = store.debug_tx.clone();
|
||||
|
||||
let file = active_buffer.read(cx).file();
|
||||
|
||||
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
|
||||
|
||||
// TODO data collection
|
||||
let can_collect_data = file
|
||||
.as_ref()
|
||||
.map_or(false, |file| store.can_collect_file(project, file, cx));
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
let eval_cache = store.eval_cache.clone();
|
||||
|
||||
let request_task = cx.background_spawn({
|
||||
let active_buffer = active_buffer.clone();
|
||||
async move {
|
||||
let cursor_offset = position.to_offset(&active_snapshot);
|
||||
let cursor_point = cursor_offset.to_point(&active_snapshot);
|
||||
|
||||
let before_retrieval = Instant::now();
|
||||
|
||||
let excerpt_options = options.context;
|
||||
|
||||
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
|
||||
cursor_point,
|
||||
&active_snapshot,
|
||||
&excerpt_options,
|
||||
) else {
|
||||
return Ok((None, None));
|
||||
};
|
||||
|
||||
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
|
||||
..active_snapshot.anchor_before(excerpt.range.end);
|
||||
let related_excerpt = RelatedExcerpt {
|
||||
anchor_range: excerpt_anchor_range.clone(),
|
||||
point_range: Point::new(excerpt.line_range.start.0, 0)
|
||||
..Point::new(excerpt.line_range.end.0, 0),
|
||||
text: active_snapshot.as_rope().slice(excerpt.range),
|
||||
};
|
||||
|
||||
if let Some(buffer_ix) = included_files
|
||||
.iter()
|
||||
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
|
||||
{
|
||||
let file = &mut included_files[buffer_ix];
|
||||
file.excerpts.push(related_excerpt);
|
||||
file.merge_excerpts();
|
||||
let last_ix = included_files.len() - 1;
|
||||
included_files.swap(buffer_ix, last_ix);
|
||||
} else {
|
||||
let active_file = RelatedFile {
|
||||
path: active_project_path,
|
||||
buffer: active_buffer.downgrade(),
|
||||
excerpts: vec![related_excerpt],
|
||||
max_row: active_snapshot.max_point().row,
|
||||
};
|
||||
included_files.push(active_file);
|
||||
}
|
||||
|
||||
let included_files = included_files
|
||||
.iter()
|
||||
.map(|related_file| predict_edits_v3::RelatedFile {
|
||||
path: Arc::from(related_file.path.path.as_std_path()),
|
||||
max_row: Line(related_file.max_row),
|
||||
excerpts: related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| predict_edits_v3::Excerpt {
|
||||
start_line: Line(excerpt.point_range.start.row),
|
||||
text: excerpt.text.to_string().into(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cloud_request = predict_edits_v3::PredictEditsRequest {
|
||||
excerpt_path,
|
||||
excerpt: String::new(),
|
||||
excerpt_line_range: Line(0)..Line(0),
|
||||
excerpt_range: 0..0,
|
||||
cursor_point: predict_edits_v3::Point {
|
||||
line: predict_edits_v3::Line(cursor_point.row),
|
||||
column: cursor_point.column,
|
||||
},
|
||||
related_files: included_files,
|
||||
events,
|
||||
can_collect_data,
|
||||
debug_info: debug_tx.is_some(),
|
||||
prompt_max_bytes: Some(options.max_prompt_bytes),
|
||||
prompt_format: options.prompt_format,
|
||||
excerpt_parent: None,
|
||||
git_info: None,
|
||||
trigger,
|
||||
};
|
||||
|
||||
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
included_files: cloud_request.related_files,
|
||||
events: cloud_request.events,
|
||||
cursor_point: cloud_request.cursor_point,
|
||||
cursor_path: cloud_request.excerpt_path,
|
||||
};
|
||||
|
||||
let retrieval_time = Instant::now() - before_retrieval;
|
||||
|
||||
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionRequested(
|
||||
EditPredictionRequestedDebugEvent {
|
||||
inputs: inputs.clone(),
|
||||
retrieval_time,
|
||||
buffer: active_buffer.downgrade(),
|
||||
local_prompt: match prompt_result.as_ref() {
|
||||
Ok(prompt) => Ok(prompt.clone()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
},
|
||||
position,
|
||||
response_rx,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
Some(response_tx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((Err("Request skipped".to_string()), Duration::ZERO))
|
||||
.ok();
|
||||
}
|
||||
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
|
||||
}
|
||||
|
||||
let prompt = prompt_result?;
|
||||
let generation_params =
|
||||
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
|
||||
let request = open_ai::Request {
|
||||
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
content: open_ai::MessageContent::Plain(prompt),
|
||||
}],
|
||||
stream: false,
|
||||
max_completion_tokens: None,
|
||||
stop: generation_params.stop.unwrap_or_default(),
|
||||
temperature: generation_params.temperature.unwrap_or(0.7),
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
tools: vec![],
|
||||
prompt_cache_key: None,
|
||||
reasoning_effort: None,
|
||||
};
|
||||
|
||||
log::trace!("Sending edit prediction request");
|
||||
|
||||
let before_request = Instant::now();
|
||||
let response = EditPredictionStore::send_raw_llm_request(
|
||||
request,
|
||||
client,
|
||||
llm_token,
|
||||
app_version,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache,
|
||||
#[cfg(feature = "eval-support")]
|
||||
EvalCacheEntryKind::Prediction,
|
||||
)
|
||||
.await;
|
||||
let received_response_at = Instant::now();
|
||||
let request_time = received_response_at - before_request;
|
||||
|
||||
log::trace!("Got edit prediction response");
|
||||
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((
|
||||
response
|
||||
.as_ref()
|
||||
.map_err(|err| err.to_string())
|
||||
.map(|response| response.0.clone()),
|
||||
request_time,
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (res, usage) = response?;
|
||||
let request_id = EditPredictionId(res.id.clone().into());
|
||||
let Some(mut output_text) = text_from_response(res) else {
|
||||
return Ok((Some((request_id, None)), usage));
|
||||
};
|
||||
|
||||
if output_text.contains(CURSOR_MARKER) {
|
||||
log::trace!("Stripping out {CURSOR_MARKER} from response");
|
||||
output_text = output_text.replace(CURSOR_MARKER, "");
|
||||
}
|
||||
|
||||
let get_buffer_from_context = |path: &Path| {
|
||||
if Some(path) == active_file_full_path.as_deref() {
|
||||
Some((
|
||||
&active_snapshot,
|
||||
std::slice::from_ref(&excerpt_anchor_range),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let (_, edits) = match options.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
if output_text.contains("--- a/\n+++ b/\nNo edits") {
|
||||
let edits = vec![];
|
||||
(&active_snapshot, edits)
|
||||
} else {
|
||||
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
}
|
||||
PromptFormat::OldTextNewText => {
|
||||
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
_ => {
|
||||
bail!("unsupported prompt format {}", options.prompt_format)
|
||||
}
|
||||
};
|
||||
|
||||
anyhow::Ok((
|
||||
Some((
|
||||
request_id,
|
||||
Some((
|
||||
inputs,
|
||||
active_buffer,
|
||||
active_snapshot.clone(),
|
||||
edits,
|
||||
received_response_at,
|
||||
)),
|
||||
)),
|
||||
usage,
|
||||
))
|
||||
}
|
||||
});
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let Some((id, prediction)) =
|
||||
EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
|
||||
prediction
|
||||
else {
|
||||
return Ok(Some(EditPredictionResult {
|
||||
id,
|
||||
prediction: Err(EditPredictionRejectReason::Empty),
|
||||
}));
|
||||
};
|
||||
|
||||
Ok(Some(
|
||||
EditPredictionResult::new(
|
||||
id,
|
||||
&edited_buffer,
|
||||
&edited_buffer_snapshot,
|
||||
edits.into(),
|
||||
buffer_snapshotted_at,
|
||||
received_response_at,
|
||||
inputs,
|
||||
cx,
|
||||
)
|
||||
.await,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
|
||||
let choice = res.choices.pop()?;
|
||||
let output_text = match choice.message {
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(content)),
|
||||
..
|
||||
} => content,
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Multipart(mut content)),
|
||||
..
|
||||
} => {
|
||||
if content.is_empty() {
|
||||
log::error!("No output from Baseten completion response");
|
||||
return None;
|
||||
}
|
||||
|
||||
match content.remove(0) {
|
||||
open_ai::MessagePart::Text { text } => text,
|
||||
open_ai::MessagePart::Image { .. } => {
|
||||
log::error!("Expected text, got an image");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
log::error!("Invalid response message: {:?}", choice.message);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
Some(output_text)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "zeta_cli"
|
||||
name = "edit_prediction_cli"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "zeta"
|
||||
name = "ep_cli"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
@@ -19,7 +19,7 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace= true
|
||||
cloud_zeta2_prompt.workspace= true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
@@ -35,9 +35,7 @@ language_models.workspace = true
|
||||
languages = { workspace = true, features = ["load-grammars"] }
|
||||
log.workspace = true
|
||||
node_runtime.workspace = true
|
||||
ordered-float.workspace = true
|
||||
paths.workspace = true
|
||||
polars = { version = "0.51", features = ["lazy", "dtype-struct", "parquet"] }
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
pulldown-cmark.workspace = true
|
||||
@@ -48,12 +46,11 @@ serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
shellexpand.workspace = true
|
||||
smol.workspace = true
|
||||
soa-rs = "0.8.1"
|
||||
terminal_view.workspace = true
|
||||
toml.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
zeta = { workspace = true, features = ["eval-support"] }
|
||||
edit_prediction = { workspace = true, features = ["eval-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -6,17 +6,17 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
use zeta::{Zeta, udiff::DiffLine};
|
||||
|
||||
use crate::{
|
||||
EvaluateArguments, PredictionOptions,
|
||||
example::{Example, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
paths::print_run_data_dir,
|
||||
predict::{PredictionDetails, perform_predict, setup_zeta},
|
||||
predict::{PredictionDetails, perform_predict, setup_store},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -45,7 +45,7 @@ pub async fn run_evaluate(
|
||||
let project = example.setup_project(&app_state, cx).await.unwrap();
|
||||
|
||||
let providers = (0..args.repetitions)
|
||||
.map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
|
||||
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
@@ -53,7 +53,7 @@ pub async fn run_evaluate(
|
||||
let tasks = providers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(repetition_ix, zeta)| {
|
||||
.map(move |(repetition_ix, store)| {
|
||||
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
|
||||
let example = example.clone();
|
||||
let project = project.clone();
|
||||
@@ -65,7 +65,7 @@ pub async fn run_evaluate(
|
||||
example,
|
||||
repetition_ix,
|
||||
project,
|
||||
zeta,
|
||||
store,
|
||||
options,
|
||||
!args.skip_prediction,
|
||||
cx,
|
||||
@@ -154,7 +154,7 @@ pub async fn run_evaluate_one(
|
||||
example: NamedExample,
|
||||
repetition_ix: Option<u16>,
|
||||
project: Entity<Project>,
|
||||
zeta: Entity<Zeta>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
prediction_options: PredictionOptions,
|
||||
predict: bool,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -162,7 +162,7 @@ pub async fn run_evaluate_one(
|
||||
let predict_result = perform_predict(
|
||||
example.clone(),
|
||||
project,
|
||||
zeta,
|
||||
store,
|
||||
repetition_ix,
|
||||
prediction_options,
|
||||
cx,
|
||||
@@ -14,6 +14,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::ValueEnum;
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use collections::HashMap;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
@@ -25,7 +26,6 @@ use project::{Project, ProjectPath};
|
||||
use pulldown_cmark::CowStr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta::udiff::OpenedBuffers;
|
||||
|
||||
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
|
||||
|
||||
@@ -481,7 +481,7 @@ impl NamedExample {
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers<'_>> {
|
||||
zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await
|
||||
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ mod metrics;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod source_location;
|
||||
mod syntax_retrieval_stats;
|
||||
mod util;
|
||||
|
||||
use crate::{
|
||||
@@ -14,27 +13,21 @@ use crate::{
|
||||
headless::ZetaCliAppState,
|
||||
predict::run_predict,
|
||||
source_location::SourceLocation,
|
||||
syntax_retrieval_stats::retrieval_stats,
|
||||
util::{open_buffer, open_buffer_with_language_server},
|
||||
};
|
||||
use ::util::paths::PathStyle;
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand, ValueEnum};
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use edit_prediction_context::{
|
||||
EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
|
||||
};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
use gpui::{Application, AsyncApp, Entity, prelude::*};
|
||||
use language::{Bias, Buffer, BufferSnapshot, Point};
|
||||
use metrics::delta_chr_f;
|
||||
use project::{Project, Worktree};
|
||||
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde_json::json;
|
||||
use std::io::{self};
|
||||
use std::time::Duration;
|
||||
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
|
||||
use zeta::ContextMode;
|
||||
use zeta::udiff::DiffLine;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeta")]
|
||||
@@ -48,7 +41,6 @@ struct ZetaCliArgs {
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Command {
|
||||
Context(ContextArgs),
|
||||
ContextStats(ContextStatsArgs),
|
||||
Predict(PredictArguments),
|
||||
Eval(EvaluateArguments),
|
||||
ConvertExample {
|
||||
@@ -63,20 +55,6 @@ enum Command {
|
||||
Clean,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct ContextStatsArgs {
|
||||
#[arg(long)]
|
||||
worktree: PathBuf,
|
||||
#[arg(long)]
|
||||
extension: Option<String>,
|
||||
#[arg(long)]
|
||||
limit: Option<usize>,
|
||||
#[arg(long)]
|
||||
skip: Option<usize>,
|
||||
#[clap(flatten)]
|
||||
zeta2_args: Zeta2Args,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct ContextArgs {
|
||||
#[arg(long)]
|
||||
@@ -97,7 +75,7 @@ struct ContextArgs {
|
||||
enum ContextProvider {
|
||||
Zeta1,
|
||||
#[default]
|
||||
Syntax,
|
||||
Zeta2,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
@@ -204,35 +182,22 @@ enum PredictionProvider {
|
||||
Sweep,
|
||||
}
|
||||
|
||||
fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
|
||||
zeta::ZetaOptions {
|
||||
context: ContextMode::Syntax(EditPredictionContextOptions {
|
||||
max_retrieved_declarations: args.max_retrieved_definitions,
|
||||
use_imports: !args.disable_imports_gathering,
|
||||
excerpt: EditPredictionExcerptOptions {
|
||||
max_bytes: args.max_excerpt_bytes,
|
||||
min_bytes: args.min_excerpt_bytes,
|
||||
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
|
||||
},
|
||||
score: EditPredictionScoreOptions {
|
||||
omit_excerpt_overlaps,
|
||||
},
|
||||
}),
|
||||
max_diagnostic_bytes: args.max_diagnostic_bytes,
|
||||
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
|
||||
edit_prediction::ZetaOptions {
|
||||
context: EditPredictionExcerptOptions {
|
||||
max_bytes: args.max_excerpt_bytes,
|
||||
min_bytes: args.min_excerpt_bytes,
|
||||
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
|
||||
},
|
||||
max_prompt_bytes: args.max_prompt_bytes,
|
||||
prompt_format: args.prompt_format.into(),
|
||||
file_indexing_parallelism: args.file_indexing_parallelism,
|
||||
buffer_change_grouping_interval: Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
enum PromptFormat {
|
||||
MarkedExcerpt,
|
||||
LabeledSections,
|
||||
OnlySnippets,
|
||||
#[default]
|
||||
NumberedLines,
|
||||
OldTextNewText,
|
||||
Minimal,
|
||||
MinimalQwen,
|
||||
@@ -242,10 +207,7 @@ enum PromptFormat {
|
||||
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
|
||||
fn into(self) -> predict_edits_v3::PromptFormat {
|
||||
match self {
|
||||
Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
|
||||
Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
|
||||
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
|
||||
Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
|
||||
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
|
||||
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
|
||||
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
|
||||
@@ -295,6 +257,7 @@ struct LoadedContext {
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
lsp_open_handle: Option<OpenLspBufferHandle>,
|
||||
}
|
||||
|
||||
async fn load_context(
|
||||
@@ -330,7 +293,7 @@ async fn load_context(
|
||||
.await?;
|
||||
|
||||
let mut ready_languages = HashSet::default();
|
||||
let (_lsp_open_handle, buffer) = if *use_language_server {
|
||||
let (lsp_open_handle, buffer) = if *use_language_server {
|
||||
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
|
||||
project.clone(),
|
||||
worktree.clone(),
|
||||
@@ -377,10 +340,11 @@ async fn load_context(
|
||||
worktree,
|
||||
project,
|
||||
buffer,
|
||||
lsp_open_handle,
|
||||
})
|
||||
}
|
||||
|
||||
async fn zeta2_syntax_context(
|
||||
async fn zeta2_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -390,6 +354,7 @@ async fn zeta2_syntax_context(
|
||||
project,
|
||||
buffer,
|
||||
clipped_cursor,
|
||||
lsp_open_handle: _handle,
|
||||
..
|
||||
} = load_context(&args, app_state, cx).await?;
|
||||
|
||||
@@ -402,34 +367,32 @@ async fn zeta2_syntax_context(
|
||||
.await;
|
||||
let output = cx
|
||||
.update(|cx| {
|
||||
let zeta = cx.new(|cx| {
|
||||
zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
|
||||
let store = cx.new(|cx| {
|
||||
edit_prediction::EditPredictionStore::new(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let indexing_done_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
|
||||
zeta.register_buffer(&buffer, &project, cx);
|
||||
zeta.wait_for_initial_indexing(&project, cx)
|
||||
store.update(cx, |store, cx| {
|
||||
store.set_options(zeta2_args_to_options(&args.zeta2_args));
|
||||
store.register_buffer(&buffer, &project, cx);
|
||||
});
|
||||
cx.spawn(async move |cx| {
|
||||
indexing_done_task.await?;
|
||||
let request = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
|
||||
zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
|
||||
})?
|
||||
.await?;
|
||||
let updates_rx = store.update(cx, |store, cx| {
|
||||
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &buffer, cursor, cx);
|
||||
store.project_context_updates(&project).unwrap()
|
||||
})?;
|
||||
|
||||
let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
|
||||
updates_rx.recv().await.ok();
|
||||
|
||||
match args.zeta2_args.output_format {
|
||||
OutputFormat::Prompt => anyhow::Ok(prompt_string),
|
||||
OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
|
||||
OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
|
||||
"request": request,
|
||||
"prompt": prompt_string,
|
||||
"section_labels": section_labels,
|
||||
}))?),
|
||||
}
|
||||
let context = store.update(cx, |store, cx| {
|
||||
store.context_for_project(&project, cx).to_vec()
|
||||
})?;
|
||||
|
||||
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
@@ -441,7 +404,7 @@ async fn zeta1_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<zeta::zeta1::GatherContextOutput> {
|
||||
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
|
||||
let LoadedContext {
|
||||
full_path_str,
|
||||
snapshot,
|
||||
@@ -456,7 +419,7 @@ async fn zeta1_context(
|
||||
|
||||
let prompt_for_events = move || (events, 0);
|
||||
cx.update(|cx| {
|
||||
zeta::zeta1::gather_context(
|
||||
edit_prediction::zeta1::gather_context(
|
||||
full_path_str,
|
||||
&snapshot,
|
||||
clipped_cursor,
|
||||
@@ -482,24 +445,10 @@ fn main() {
|
||||
None => {
|
||||
if args.printenv {
|
||||
::util::shell_env::print_env();
|
||||
return;
|
||||
} else {
|
||||
panic!("Expected a command");
|
||||
}
|
||||
}
|
||||
Some(Command::ContextStats(arguments)) => {
|
||||
let result = retrieval_stats(
|
||||
arguments.worktree,
|
||||
app_state,
|
||||
arguments.extension,
|
||||
arguments.limit,
|
||||
arguments.skip,
|
||||
zeta2_args_to_options(&arguments.zeta2_args, false),
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
println!("{}", result.unwrap());
|
||||
}
|
||||
Some(Command::Context(context_args)) => {
|
||||
let result = match context_args.provider {
|
||||
ContextProvider::Zeta1 => {
|
||||
@@ -507,10 +456,8 @@ fn main() {
|
||||
zeta1_context(context_args, &app_state, cx).await.unwrap();
|
||||
serde_json::to_string_pretty(&context.body).unwrap()
|
||||
}
|
||||
ContextProvider::Syntax => {
|
||||
zeta2_syntax_context(context_args, &app_state, cx)
|
||||
.await
|
||||
.unwrap()
|
||||
ContextProvider::Zeta2 => {
|
||||
zeta2_context(context_args, &app_state, cx).await.unwrap()
|
||||
}
|
||||
};
|
||||
println!("{}", result);
|
||||
@@ -1,5 +1,5 @@
|
||||
use collections::{HashMap, HashSet};
|
||||
use zeta::udiff::DiffLine;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
|
||||
type Counts = HashMap<String, usize>;
|
||||
type CountsDelta = HashMap<String, isize>;
|
||||
@@ -287,7 +287,7 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use zeta::udiff::DiffLine;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
|
||||
#[test]
|
||||
fn test_delta_chr_f_perfect_match() {
|
||||
@@ -7,6 +7,7 @@ use crate::{
|
||||
use ::serde::Serialize;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
|
||||
use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{AppContext, AsyncApp, Entity};
|
||||
use project::Project;
|
||||
@@ -18,7 +19,6 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
|
||||
|
||||
pub async fn run_predict(
|
||||
args: PredictArguments,
|
||||
@@ -27,9 +27,9 @@ pub async fn run_predict(
|
||||
) {
|
||||
let example = NamedExample::load(args.example_path).unwrap();
|
||||
let project = example.setup_project(app_state, cx).await.unwrap();
|
||||
let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
|
||||
let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
let result = perform_predict(example, project, zeta, None, args.options, cx)
|
||||
let result = perform_predict(example, project, store, None, args.options, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
result.write(args.format, std::io::stdout()).unwrap();
|
||||
@@ -37,45 +37,50 @@ pub async fn run_predict(
|
||||
print_run_data_dir(true, std::io::stdout().is_terminal());
|
||||
}
|
||||
|
||||
pub fn setup_zeta(
|
||||
pub fn setup_store(
|
||||
provider: PredictionProvider,
|
||||
project: &Entity<Project>,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Zeta>> {
|
||||
let zeta =
|
||||
cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
|
||||
) -> Result<Entity<EditPredictionStore>> {
|
||||
let store = cx.new(|cx| {
|
||||
edit_prediction::EditPredictionStore::new(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
store.update(cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
};
|
||||
zeta.set_edit_prediction_model(model);
|
||||
store.set_edit_prediction_model(model);
|
||||
})?;
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
let zeta = zeta.clone();
|
||||
let store = store.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
|
||||
store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
|
||||
anyhow::Ok(zeta)
|
||||
anyhow::Ok(store)
|
||||
}
|
||||
|
||||
pub async fn perform_predict(
|
||||
example: NamedExample,
|
||||
project: Entity<Project>,
|
||||
zeta: Entity<Zeta>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
repetition_ix: Option<u16>,
|
||||
options: PredictionOptions,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -108,8 +113,8 @@ pub async fn perform_predict(
|
||||
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
|
||||
.context("creating latest link")?;
|
||||
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
zeta.with_eval_cache(Arc::new(RunCache {
|
||||
store.update(cx, |store, _cx| {
|
||||
store.with_eval_cache(Arc::new(RunCache {
|
||||
example_run_dir: example_run_dir.clone(),
|
||||
cache_mode,
|
||||
}));
|
||||
@@ -121,44 +126,43 @@ pub async fn perform_predict(
|
||||
|
||||
let prompt_format = options.zeta2.prompt_format;
|
||||
|
||||
zeta.update(cx, |zeta, _cx| {
|
||||
let mut options = zeta.options().clone();
|
||||
store.update(cx, |store, _cx| {
|
||||
let mut options = store.options().clone();
|
||||
options.prompt_format = prompt_format.into();
|
||||
zeta.set_options(options);
|
||||
store.set_options(options);
|
||||
})?;
|
||||
|
||||
let mut debug_task = gpui::Task::ready(Ok(()));
|
||||
|
||||
if options.provider == crate::PredictionProvider::Zeta2 {
|
||||
let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
|
||||
let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
|
||||
|
||||
debug_task = cx.background_spawn({
|
||||
let result = result.clone();
|
||||
async move {
|
||||
let mut start_time = None;
|
||||
let mut search_queries_generated_at = None;
|
||||
let mut search_queries_executed_at = None;
|
||||
let mut retrieval_finished_at = None;
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
match event {
|
||||
zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
|
||||
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
|
||||
start_time = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_prompt.md"),
|
||||
&info.search_prompt,
|
||||
)?;
|
||||
}
|
||||
zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
|
||||
search_queries_generated_at = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_queries.json"),
|
||||
serde_json::to_string_pretty(&info.search_queries).unwrap(),
|
||||
)?;
|
||||
edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
|
||||
retrieval_finished_at = Some(info.timestamp);
|
||||
for (key, value) in &info.metadata {
|
||||
if *key == "search_queries" {
|
||||
fs::write(
|
||||
example_run_dir.join("search_queries.json"),
|
||||
value.as_bytes(),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
|
||||
search_queries_executed_at = Some(info.timestamp);
|
||||
}
|
||||
zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
|
||||
zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
|
||||
edit_prediction::DebugEvent::EditPredictionRequested(request) => {
|
||||
let prediction_started_at = Instant::now();
|
||||
start_time.get_or_insert(prediction_started_at);
|
||||
let prompt = request.local_prompt.unwrap_or_default();
|
||||
@@ -194,19 +198,15 @@ pub async fn perform_predict(
|
||||
|
||||
let response =
|
||||
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
let response = zeta::text_from_response(response).unwrap_or_default();
|
||||
let response = edit_prediction::zeta2::text_from_response(response)
|
||||
.unwrap_or_default();
|
||||
let prediction_finished_at = Instant::now();
|
||||
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
|
||||
|
||||
let mut result = result.lock().unwrap();
|
||||
result.generated_len = response.chars().count();
|
||||
|
||||
result.planning_search_time =
|
||||
Some(search_queries_generated_at.unwrap() - start_time.unwrap());
|
||||
result.running_search_time = Some(
|
||||
search_queries_executed_at.unwrap()
|
||||
- search_queries_generated_at.unwrap(),
|
||||
);
|
||||
result.retrieval_time =
|
||||
retrieval_finished_at.unwrap() - start_time.unwrap();
|
||||
result.prediction_time = prediction_finished_at - prediction_started_at;
|
||||
result.total_time = prediction_finished_at - start_time.unwrap();
|
||||
|
||||
@@ -218,15 +218,14 @@ pub async fn perform_predict(
|
||||
}
|
||||
});
|
||||
|
||||
zeta.update(cx, |zeta, cx| {
|
||||
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
|
||||
})?
|
||||
.await?;
|
||||
store.update(cx, |store, cx| {
|
||||
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
|
||||
})?;
|
||||
}
|
||||
|
||||
let prediction = zeta
|
||||
.update(cx, |zeta, cx| {
|
||||
zeta.request_prediction(
|
||||
let prediction = store
|
||||
.update(cx, |store, cx| {
|
||||
store.request_prediction(
|
||||
&project,
|
||||
&cursor_buffer,
|
||||
cursor_anchor,
|
||||
@@ -321,8 +320,7 @@ pub struct PredictionDetails {
|
||||
pub diff: String,
|
||||
pub excerpts: Vec<ActualExcerpt>,
|
||||
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
|
||||
pub planning_search_time: Option<Duration>,
|
||||
pub running_search_time: Option<Duration>,
|
||||
pub retrieval_time: Duration,
|
||||
pub prediction_time: Duration,
|
||||
pub total_time: Duration,
|
||||
pub run_example_dir: PathBuf,
|
||||
@@ -336,8 +334,7 @@ impl PredictionDetails {
|
||||
diff: Default::default(),
|
||||
excerpts: Default::default(),
|
||||
excerpts_text: Default::default(),
|
||||
planning_search_time: Default::default(),
|
||||
running_search_time: Default::default(),
|
||||
retrieval_time: Default::default(),
|
||||
prediction_time: Default::default(),
|
||||
total_time: Default::default(),
|
||||
run_example_dir,
|
||||
@@ -357,28 +354,20 @@ impl PredictionDetails {
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
|
||||
|
||||
format!(
|
||||
"## Excerpts\n\n\
|
||||
{}\n\n\
|
||||
## Prediction\n\n\
|
||||
{}\n\n\
|
||||
## Time\n\n\
|
||||
Planning searches: {}ms\n\
|
||||
Running searches: {}ms\n\
|
||||
Making Prediction: {}ms\n\n\
|
||||
-------------------\n\n\
|
||||
Total: {}ms\n\
|
||||
Inference: {}ms ({:.2}%)\n",
|
||||
Retrieval: {}ms\n\
|
||||
Prediction: {}ms\n\n\
|
||||
Total: {}ms\n",
|
||||
self.excerpts_text,
|
||||
self.diff,
|
||||
self.planning_search_time.unwrap_or_default().as_millis(),
|
||||
self.running_search_time.unwrap_or_default().as_millis(),
|
||||
self.retrieval_time.as_millis(),
|
||||
self.prediction_time.as_millis(),
|
||||
self.total_time.as_millis(),
|
||||
inference_time.as_millis(),
|
||||
(inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,8 @@ use anyhow::{Result, anyhow};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language::{Buffer, LanguageId, LanguageServerId, ParseStatus};
|
||||
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
|
||||
use project::lsp_store::OpenLspBufferHandle;
|
||||
use project::{Project, ProjectPath, Worktree};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
@@ -40,7 +41,7 @@ pub async fn open_buffer_with_language_server(
|
||||
path: Arc<RelPath>,
|
||||
ready_languages: &mut HashSet<LanguageId>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
|
||||
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
|
||||
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
|
||||
|
||||
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
|
||||
@@ -50,6 +51,17 @@ pub async fn open_buffer_with_language_server(
|
||||
)
|
||||
})?;
|
||||
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(path.as_std_path())
|
||||
.await;
|
||||
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
anyhow::bail!(error);
|
||||
}
|
||||
|
||||
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
|
||||
buffer.language().map(|language| language.id())
|
||||
})?
|
||||
@@ -57,9 +69,9 @@ pub async fn open_buffer_with_language_server(
|
||||
return Err(anyhow!("No language for {}", path.display(path_style)));
|
||||
};
|
||||
|
||||
let log_prefix = path.display(path_style);
|
||||
let log_prefix = format!("{} | ", path.display(path_style));
|
||||
if !ready_languages.contains(&language_id) {
|
||||
wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
|
||||
ready_languages.insert(language_id);
|
||||
}
|
||||
|
||||
@@ -95,7 +107,7 @@ pub fn wait_for_lang_server(
|
||||
log_prefix: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
println!("{}⏵ Waiting for language server", log_prefix);
|
||||
eprintln!("{}⏵ Waiting for language server", log_prefix);
|
||||
|
||||
let (mut tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
@@ -137,7 +149,7 @@ pub fn wait_for_lang_server(
|
||||
..
|
||||
} = event
|
||||
{
|
||||
println!("{}⟲ {message}", log_prefix)
|
||||
eprintln!("{}⟲ {message}", log_prefix)
|
||||
}
|
||||
}
|
||||
}),
|
||||
@@ -162,7 +174,7 @@ pub fn wait_for_lang_server(
|
||||
cx.spawn(async move |cx| {
|
||||
if !has_lang_server {
|
||||
// some buffers never have a language server, so this aborts quickly in that case.
|
||||
let timeout = cx.background_executor().timer(Duration::from_secs(5));
|
||||
let timeout = cx.background_executor().timer(Duration::from_secs(500));
|
||||
futures::select! {
|
||||
_ = added_rx.next() => {},
|
||||
_ = timeout.fuse() => {
|
||||
@@ -173,7 +185,7 @@ pub fn wait_for_lang_server(
|
||||
let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
|
||||
let result = futures::select! {
|
||||
_ = rx.next() => {
|
||||
println!("{}⚑ Language server idle", log_prefix);
|
||||
eprintln!("{}⚑ Language server idle", log_prefix);
|
||||
anyhow::Ok(())
|
||||
},
|
||||
_ = timeout.fuse() => {
|
||||
@@ -12,41 +12,32 @@ workspace = true
|
||||
path = "src/edit_prediction_context.rs"
|
||||
|
||||
[dependencies]
|
||||
parking_lot.workspace = true
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
hashbrown.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
ordered-float.workspace = true
|
||||
postage.workspace = true
|
||||
lsp.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
log.workspace = true
|
||||
serde.workspace = true
|
||||
slotmap.workspace = true
|
||||
strum.workspace = true
|
||||
text.workspace = true
|
||||
smallvec.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clap.workspace = true
|
||||
env_logger.workspace = true
|
||||
indoc.workspace = true
|
||||
futures.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
project = {workspace= true, features = ["test-support"]}
|
||||
serde_json.workspace = true
|
||||
settings = {workspace= true, features = ["test-support"]}
|
||||
text = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-c.workspace = true
|
||||
tree-sitter-cpp.workspace = true
|
||||
tree-sitter-go.workspace = true
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
161
crates/edit_prediction_context/src/assemble_excerpts.rs
Normal file
161
crates/edit_prediction_context/src/assemble_excerpts.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
use crate::RelatedExcerpt;
|
||||
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::ops::Range;
|
||||
|
||||
#[cfg(not(test))]
|
||||
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
|
||||
#[cfg(test)]
|
||||
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 24;
|
||||
|
||||
pub fn assemble_excerpts(
|
||||
buffer: &BufferSnapshot,
|
||||
mut input_ranges: Vec<Range<Point>>,
|
||||
) -> Vec<RelatedExcerpt> {
|
||||
merge_ranges(&mut input_ranges);
|
||||
|
||||
let mut outline_ranges = Vec::new();
|
||||
let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
|
||||
let mut outline_ix = 0;
|
||||
for input_range in &mut input_ranges {
|
||||
*input_range = clip_range_to_lines(input_range, false, buffer);
|
||||
|
||||
while let Some(outline_item) = outline_items.get(outline_ix) {
|
||||
let item_range = clip_range_to_lines(&outline_item.range, false, buffer);
|
||||
|
||||
if item_range.start > input_range.start {
|
||||
break;
|
||||
}
|
||||
|
||||
if item_range.end > input_range.start {
|
||||
let body_range = outline_item
|
||||
.body_range(buffer)
|
||||
.map(|body| clip_range_to_lines(&body, true, buffer))
|
||||
.filter(|body_range| {
|
||||
body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE
|
||||
});
|
||||
|
||||
add_outline_item(
|
||||
item_range.clone(),
|
||||
body_range.clone(),
|
||||
buffer,
|
||||
&mut outline_ranges,
|
||||
);
|
||||
|
||||
if let Some(body_range) = body_range
|
||||
&& input_range.start < body_range.start
|
||||
{
|
||||
let mut child_outline_ix = outline_ix + 1;
|
||||
while let Some(next_outline_item) = outline_items.get(child_outline_ix) {
|
||||
if next_outline_item.range.end > body_range.end {
|
||||
break;
|
||||
}
|
||||
if next_outline_item.depth == outline_item.depth + 1 {
|
||||
let next_item_range =
|
||||
clip_range_to_lines(&next_outline_item.range, false, buffer);
|
||||
|
||||
add_outline_item(
|
||||
next_item_range,
|
||||
next_outline_item
|
||||
.body_range(buffer)
|
||||
.map(|body| clip_range_to_lines(&body, true, buffer)),
|
||||
buffer,
|
||||
&mut outline_ranges,
|
||||
);
|
||||
}
|
||||
child_outline_ix += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outline_ix += 1;
|
||||
}
|
||||
}
|
||||
|
||||
input_ranges.extend_from_slice(&outline_ranges);
|
||||
merge_ranges(&mut input_ranges);
|
||||
|
||||
input_ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
let offset_range = range.to_offset(buffer);
|
||||
RelatedExcerpt {
|
||||
point_range: range,
|
||||
anchor_range: buffer.anchor_before(offset_range.start)
|
||||
..buffer.anchor_after(offset_range.end),
|
||||
text: buffer.as_rope().slice(offset_range),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn clip_range_to_lines(
|
||||
range: &Range<Point>,
|
||||
inward: bool,
|
||||
buffer: &BufferSnapshot,
|
||||
) -> Range<Point> {
|
||||
let mut range = range.clone();
|
||||
if inward {
|
||||
if range.start.column > 0 {
|
||||
range.start.column = buffer.line_len(range.start.row);
|
||||
}
|
||||
range.end.column = 0;
|
||||
} else {
|
||||
range.start.column = 0;
|
||||
if range.end.column > 0 {
|
||||
range.end.column = buffer.line_len(range.end.row);
|
||||
}
|
||||
}
|
||||
range
|
||||
}
|
||||
|
||||
fn add_outline_item(
|
||||
mut item_range: Range<Point>,
|
||||
body_range: Option<Range<Point>>,
|
||||
buffer: &BufferSnapshot,
|
||||
outline_ranges: &mut Vec<Range<Point>>,
|
||||
) {
|
||||
if let Some(mut body_range) = body_range {
|
||||
if body_range.start.column > 0 {
|
||||
body_range.start.column = buffer.line_len(body_range.start.row);
|
||||
}
|
||||
body_range.end.column = 0;
|
||||
|
||||
let head_range = item_range.start..body_range.start;
|
||||
if head_range.start < head_range.end {
|
||||
outline_ranges.push(head_range);
|
||||
}
|
||||
|
||||
let tail_range = body_range.end..item_range.end;
|
||||
if tail_range.start < tail_range.end {
|
||||
outline_ranges.push(tail_range);
|
||||
}
|
||||
} else {
|
||||
item_range.start.column = 0;
|
||||
item_range.end.column = buffer.line_len(item_range.end.row);
|
||||
outline_ranges.push(item_range);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge_ranges(ranges: &mut Vec<Range<Point>>) {
|
||||
ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
|
||||
|
||||
let mut index = 1;
|
||||
while index < ranges.len() {
|
||||
let mut prev_range_end = ranges[index - 1].end;
|
||||
if prev_range_end.column > 0 {
|
||||
prev_range_end += Point::new(1, 0);
|
||||
}
|
||||
|
||||
if (prev_range_end + Point::new(1, 0))
|
||||
.cmp(&ranges[index].start)
|
||||
.is_ge()
|
||||
{
|
||||
let removed = ranges.remove(index);
|
||||
if removed.end.cmp(&ranges[index - 1].end).is_gt() {
|
||||
ranges[index - 1].end = removed.end;
|
||||
}
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,350 +0,0 @@
|
||||
use cloud_llm_client::predict_edits_v3::{self, Line};
|
||||
use language::{Language, LanguageId};
|
||||
use project::ProjectEntryId;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, path::Path};
|
||||
use text::{Bias, BufferId, Rope};
|
||||
use util::paths::{path_ends_with, strip_path_suffix};
|
||||
use util::rel_path::RelPath;
|
||||
|
||||
use crate::outline::OutlineDeclaration;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct Identifier {
|
||||
pub name: Arc<str>,
|
||||
pub language_id: LanguageId,
|
||||
}
|
||||
|
||||
slotmap::new_key_type! {
|
||||
pub struct DeclarationId;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Declaration {
|
||||
File {
|
||||
project_entry_id: ProjectEntryId,
|
||||
declaration: FileDeclaration,
|
||||
cached_path: CachedDeclarationPath,
|
||||
},
|
||||
Buffer {
|
||||
project_entry_id: ProjectEntryId,
|
||||
buffer_id: BufferId,
|
||||
rope: Rope,
|
||||
declaration: BufferDeclaration,
|
||||
cached_path: CachedDeclarationPath,
|
||||
},
|
||||
}
|
||||
|
||||
const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
|
||||
|
||||
impl Declaration {
|
||||
pub fn identifier(&self) -> &Identifier {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => &declaration.identifier,
|
||||
Declaration::Buffer { declaration, .. } => &declaration.identifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parent(&self) -> Option<DeclarationId> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.parent,
|
||||
Declaration::Buffer { declaration, .. } => declaration.parent,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
|
||||
match self {
|
||||
Declaration::File { .. } => None,
|
||||
Declaration::Buffer { declaration, .. } => Some(declaration),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_file(&self) -> Option<&FileDeclaration> {
|
||||
match self {
|
||||
Declaration::Buffer { .. } => None,
|
||||
Declaration::File { declaration, .. } => Some(declaration),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project_entry_id(&self) -> ProjectEntryId {
|
||||
match self {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => *project_entry_id,
|
||||
Declaration::Buffer {
|
||||
project_entry_id, ..
|
||||
} => *project_entry_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cached_path(&self) -> &CachedDeclarationPath {
|
||||
match self {
|
||||
Declaration::File { cached_path, .. } => cached_path,
|
||||
Declaration::Buffer { cached_path, .. } => cached_path,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_range(&self) -> Range<usize> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.item_range.clone(),
|
||||
Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_line_range(&self) -> Range<Line> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.item_line_range.clone(),
|
||||
Declaration::Buffer {
|
||||
declaration, rope, ..
|
||||
} => {
|
||||
Line(rope.offset_to_point(declaration.item_range.start).row)
|
||||
..Line(rope.offset_to_point(declaration.item_range.end).row)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text.as_ref().into(),
|
||||
declaration.text_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.item_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.item_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text[self.signature_range_in_item_text()].into(),
|
||||
declaration.signature_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.signature_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.signature_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_range(&self) -> Range<usize> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.signature_range.clone(),
|
||||
Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_line_range(&self) -> Range<Line> {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => declaration.signature_line_range.clone(),
|
||||
Declaration::Buffer {
|
||||
declaration, rope, ..
|
||||
} => {
|
||||
Line(rope.offset_to_point(declaration.signature_range.start).row)
|
||||
..Line(rope.offset_to_point(declaration.signature_range.end).row)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_range_in_item_text(&self) -> Range<usize> {
|
||||
let signature_range = self.signature_range();
|
||||
let item_range = self.item_range();
|
||||
signature_range.start.saturating_sub(item_range.start)
|
||||
..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len())
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_range_to_line_boundaries_and_truncate(
|
||||
range: &Range<usize>,
|
||||
limit: usize,
|
||||
rope: &Rope,
|
||||
) -> (Range<usize>, Range<predict_edits_v3::Line>, bool) {
|
||||
let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
|
||||
point_range.start.column = 0;
|
||||
point_range.end.row += 1;
|
||||
point_range.end.column = 0;
|
||||
|
||||
let mut item_range =
|
||||
rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
|
||||
let is_truncated = item_range.len() > limit;
|
||||
if is_truncated {
|
||||
item_range.end = item_range.start + limit;
|
||||
}
|
||||
item_range.end = rope.clip_offset(item_range.end, Bias::Left);
|
||||
|
||||
let line_range =
|
||||
predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row);
|
||||
(item_range, line_range, is_truncated)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
/// offset range of the declaration in the file, expanded to line boundaries and truncated
|
||||
pub item_range: Range<usize>,
|
||||
/// line range of the declaration in the file, potentially truncated
|
||||
pub item_line_range: Range<predict_edits_v3::Line>,
|
||||
/// text of `item_range`
|
||||
pub text: Arc<str>,
|
||||
/// whether `text` was truncated
|
||||
pub text_is_truncated: bool,
|
||||
/// offset range of the signature in the file, expanded to line boundaries and truncated
|
||||
pub signature_range: Range<usize>,
|
||||
/// line range of the signature in the file, truncated
|
||||
pub signature_line_range: Range<Line>,
|
||||
/// whether `signature` was truncated
|
||||
pub signature_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl FileDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
|
||||
let (item_range_in_file, item_line_range_in_file, text_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
|
||||
let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.signature_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
|
||||
if signature_range_in_file.start < item_range_in_file.start {
|
||||
signature_range_in_file.start = item_range_in_file.start;
|
||||
signature_is_truncated = true;
|
||||
}
|
||||
if signature_range_in_file.end > item_range_in_file.end {
|
||||
signature_range_in_file.end = item_range_in_file.end;
|
||||
signature_is_truncated = true;
|
||||
}
|
||||
|
||||
FileDeclaration {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
signature_range: signature_range_in_file,
|
||||
signature_line_range,
|
||||
signature_is_truncated,
|
||||
text: rope
|
||||
.chunks_in_range(item_range_in_file.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
text_is_truncated,
|
||||
item_range: item_range_in_file,
|
||||
item_line_range: item_line_range_in_file,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BufferDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
pub item_range: Range<usize>,
|
||||
pub item_range_is_truncated: bool,
|
||||
pub signature_range: Range<usize>,
|
||||
pub signature_range_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl BufferDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
|
||||
let (item_range, _item_line_range, item_range_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
let (signature_range, _signature_line_range, signature_range_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.signature_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
Self {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
item_range,
|
||||
item_range_is_truncated,
|
||||
signature_range,
|
||||
signature_range_is_truncated,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CachedDeclarationPath {
|
||||
pub worktree_abs_path: Arc<Path>,
|
||||
pub rel_path: Arc<RelPath>,
|
||||
/// The relative path of the file, possibly stripped according to `import_path_strip_regex`.
|
||||
pub rel_path_after_regex_stripping: Arc<RelPath>,
|
||||
}
|
||||
|
||||
impl CachedDeclarationPath {
|
||||
pub fn new(
|
||||
worktree_abs_path: Arc<Path>,
|
||||
path: &Arc<RelPath>,
|
||||
language: Option<&Arc<Language>>,
|
||||
) -> Self {
|
||||
let rel_path = path.clone();
|
||||
let rel_path_after_regex_stripping = if let Some(language) = language
|
||||
&& let Some(strip_regex) = language.config().import_path_strip_regex.as_ref()
|
||||
&& let Ok(stripped) = RelPath::unix(&Path::new(
|
||||
strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(),
|
||||
)) {
|
||||
Arc::from(stripped)
|
||||
} else {
|
||||
rel_path.clone()
|
||||
};
|
||||
CachedDeclarationPath {
|
||||
worktree_abs_path,
|
||||
rel_path,
|
||||
rel_path_after_regex_stripping,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self {
|
||||
let rel_path: Arc<RelPath> = util::rel_path::rel_path(rel_path).into();
|
||||
CachedDeclarationPath {
|
||||
worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(),
|
||||
rel_path_after_regex_stripping: rel_path.clone(),
|
||||
rel_path,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ends_with_posix_path(&self, path: &Path) -> bool {
|
||||
if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() {
|
||||
path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path)
|
||||
} else {
|
||||
if let Some(remaining) =
|
||||
strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path())
|
||||
{
|
||||
path_ends_with(&self.worktree_abs_path, remaining)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn equals_absolute_path(&self, path: &Path) -> bool {
|
||||
if let Some(remaining) =
|
||||
strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path())
|
||||
{
|
||||
self.worktree_abs_path.as_ref() == remaining
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,539 +0,0 @@
|
||||
use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
|
||||
use collections::HashMap;
|
||||
use language::BufferSnapshot;
|
||||
use ordered_float::OrderedFloat;
|
||||
use project::ProjectEntryId;
|
||||
use serde::Serialize;
|
||||
use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
|
||||
use strum::EnumIter;
|
||||
use text::{Point, ToPoint};
|
||||
use util::RangeExt as _;
|
||||
|
||||
use crate::{
|
||||
CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
|
||||
imports::{Import, Imports, Module},
|
||||
reference::{Reference, ReferenceRegion},
|
||||
syntax_index::SyntaxIndexState,
|
||||
text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
|
||||
};
|
||||
|
||||
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct EditPredictionScoreOptions {
|
||||
pub omit_excerpt_overlaps: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ScoredDeclaration {
|
||||
/// identifier used by the local reference
|
||||
pub identifier: Identifier,
|
||||
pub declaration: Declaration,
|
||||
pub components: DeclarationScoreComponents,
|
||||
}
|
||||
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum DeclarationStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Default)]
|
||||
pub struct DeclarationScores {
|
||||
pub signature: f32,
|
||||
pub declaration: f32,
|
||||
pub retrieval: f32,
|
||||
}
|
||||
|
||||
impl ScoredDeclaration {
|
||||
/// Returns the score for this declaration with the specified style.
|
||||
pub fn score(&self, style: DeclarationStyle) -> f32 {
|
||||
// TODO: handle truncation
|
||||
|
||||
// Score related to how likely this is the correct declaration, range 0 to 1
|
||||
let retrieval = self.retrieval_score();
|
||||
|
||||
// Score related to the distance between the reference and cursor, range 0 to 1
|
||||
let distance_score = if self.components.is_referenced_nearby {
|
||||
1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
|
||||
} else {
|
||||
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
|
||||
0.5
|
||||
};
|
||||
|
||||
// For now instead of linear combination, the scores are just multiplied together.
|
||||
let combined_score = 10.0 * retrieval * distance_score;
|
||||
|
||||
match style {
|
||||
DeclarationStyle::Signature => {
|
||||
combined_score * self.components.excerpt_vs_signature_weighted_overlap
|
||||
}
|
||||
DeclarationStyle::Declaration => {
|
||||
2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retrieval_score(&self) -> f32 {
|
||||
let mut score = if self.components.is_same_file {
|
||||
10.0 / self.components.same_file_declaration_count as f32
|
||||
} else if self.components.path_import_match_count > 0 {
|
||||
3.0
|
||||
} else if self.components.wildcard_path_import_match_count > 0 {
|
||||
1.0
|
||||
} else if self.components.normalized_import_similarity > 0.0 {
|
||||
self.components.normalized_import_similarity
|
||||
} else if self.components.normalized_wildcard_import_similarity > 0.0 {
|
||||
0.5 * self.components.normalized_wildcard_import_similarity
|
||||
} else {
|
||||
1.0 / self.components.declaration_count as f32
|
||||
};
|
||||
score *= 1. + self.components.included_by_others as f32 / 2.;
|
||||
score *= 1. + self.components.includes_others as f32 / 4.;
|
||||
score
|
||||
}
|
||||
|
||||
pub fn size(&self, style: DeclarationStyle) -> usize {
|
||||
match &self.declaration {
|
||||
Declaration::File { declaration, .. } => match style {
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.text.len(),
|
||||
},
|
||||
Declaration::Buffer { declaration, .. } => match style {
|
||||
DeclarationStyle::Signature => declaration.signature_range.len(),
|
||||
DeclarationStyle::Declaration => declaration.item_range.len(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn score_density(&self, style: DeclarationStyle) -> f32 {
|
||||
self.score(style) / self.size(style) as f32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scored_declarations(
|
||||
options: &EditPredictionScoreOptions,
|
||||
index: &SyntaxIndexState,
|
||||
excerpt: &EditPredictionExcerpt,
|
||||
excerpt_occurrences: &Occurrences,
|
||||
adjacent_occurrences: &Occurrences,
|
||||
imports: &Imports,
|
||||
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
|
||||
cursor_offset: usize,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Vec<ScoredDeclaration> {
|
||||
let cursor_point = cursor_offset.to_point(¤t_buffer);
|
||||
|
||||
let mut wildcard_import_occurrences = Vec::new();
|
||||
let mut wildcard_import_paths = Vec::new();
|
||||
for wildcard_import in imports.wildcard_modules.iter() {
|
||||
match wildcard_import {
|
||||
Module::Namespace(namespace) => {
|
||||
wildcard_import_occurrences.push(namespace.occurrences())
|
||||
}
|
||||
Module::SourceExact(path) => wildcard_import_paths.push(path),
|
||||
Module::SourceFuzzy(path) => {
|
||||
wildcard_import_occurrences.push(Occurrences::from_path(&path))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut scored_declarations = Vec::new();
|
||||
let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
|
||||
HashMap::default();
|
||||
for (identifier, references) in identifier_to_references {
|
||||
let mut import_occurrences = Vec::new();
|
||||
let mut import_paths = Vec::new();
|
||||
let mut found_external_identifier: Option<&Identifier> = None;
|
||||
|
||||
if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
|
||||
// only use alias when it's the only import, could be generalized if some language
|
||||
// has overlapping aliases
|
||||
//
|
||||
// TODO: when an aliased declaration is included in the prompt, should include the
|
||||
// aliasing in the prompt.
|
||||
//
|
||||
// TODO: For SourceFuzzy consider having componentwise comparison that pays
|
||||
// attention to ordering.
|
||||
if let [
|
||||
Import::Alias {
|
||||
module,
|
||||
external_identifier,
|
||||
},
|
||||
] = imports.as_slice()
|
||||
{
|
||||
match module {
|
||||
Module::Namespace(namespace) => {
|
||||
import_occurrences.push(namespace.occurrences())
|
||||
}
|
||||
Module::SourceExact(path) => import_paths.push(path),
|
||||
Module::SourceFuzzy(path) => {
|
||||
import_occurrences.push(Occurrences::from_path(&path))
|
||||
}
|
||||
}
|
||||
found_external_identifier = Some(&external_identifier);
|
||||
} else {
|
||||
for import in imports {
|
||||
match import {
|
||||
Import::Direct { module } => match module {
|
||||
Module::Namespace(namespace) => {
|
||||
import_occurrences.push(namespace.occurrences())
|
||||
}
|
||||
Module::SourceExact(path) => import_paths.push(path),
|
||||
Module::SourceFuzzy(path) => {
|
||||
import_occurrences.push(Occurrences::from_path(&path))
|
||||
}
|
||||
},
|
||||
Import::Alias { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
|
||||
// TODO: update this to be able to return more declarations? Especially if there is the
|
||||
// ability to quickly filter a large list (based on imports)
|
||||
let identifier_declarations = index
|
||||
.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
|
||||
let declaration_count = identifier_declarations.len();
|
||||
|
||||
if declaration_count == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: option to filter out other candidates when same file / import match
|
||||
let mut checked_declarations = Vec::with_capacity(declaration_count);
|
||||
for (declaration_id, declaration) in identifier_declarations {
|
||||
match declaration {
|
||||
Declaration::Buffer {
|
||||
buffer_id,
|
||||
declaration: buffer_declaration,
|
||||
..
|
||||
} => {
|
||||
if buffer_id == ¤t_buffer.remote_id() {
|
||||
let already_included_in_prompt =
|
||||
range_intersection(&buffer_declaration.item_range, &excerpt.range)
|
||||
.is_some()
|
||||
|| excerpt
|
||||
.parent_declarations
|
||||
.iter()
|
||||
.any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
|
||||
if !options.omit_excerpt_overlaps || !already_included_in_prompt {
|
||||
let declaration_line = buffer_declaration
|
||||
.item_range
|
||||
.start
|
||||
.to_point(current_buffer)
|
||||
.row;
|
||||
let declaration_line_distance =
|
||||
(cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
|
||||
checked_declarations.push(CheckedDeclaration {
|
||||
declaration,
|
||||
same_file_line_distance: Some(declaration_line_distance),
|
||||
path_import_match_count: 0,
|
||||
wildcard_path_import_match_count: 0,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
}
|
||||
}
|
||||
Declaration::File { .. } => {}
|
||||
}
|
||||
let declaration_path = declaration.cached_path();
|
||||
let path_import_match_count = import_paths
|
||||
.iter()
|
||||
.filter(|import_path| {
|
||||
declaration_path_matches_import(&declaration_path, import_path)
|
||||
})
|
||||
.count();
|
||||
let wildcard_path_import_match_count = wildcard_import_paths
|
||||
.iter()
|
||||
.filter(|import_path| {
|
||||
declaration_path_matches_import(&declaration_path, import_path)
|
||||
})
|
||||
.count();
|
||||
checked_declarations.push(CheckedDeclaration {
|
||||
declaration,
|
||||
same_file_line_distance: None,
|
||||
path_import_match_count,
|
||||
wildcard_path_import_match_count,
|
||||
});
|
||||
}
|
||||
|
||||
let mut max_import_similarity = 0.0;
|
||||
let mut max_wildcard_import_similarity = 0.0;
|
||||
|
||||
let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
|
||||
for checked_declaration in checked_declarations {
|
||||
let same_file_declaration_count =
|
||||
index.file_declaration_count(checked_declaration.declaration);
|
||||
|
||||
let declaration = score_declaration(
|
||||
&identifier,
|
||||
&references,
|
||||
checked_declaration,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
&excerpt_occurrences,
|
||||
&adjacent_occurrences,
|
||||
&import_occurrences,
|
||||
&wildcard_import_occurrences,
|
||||
cursor_point,
|
||||
current_buffer,
|
||||
);
|
||||
|
||||
if declaration.components.import_similarity > max_import_similarity {
|
||||
max_import_similarity = declaration.components.import_similarity;
|
||||
}
|
||||
|
||||
if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
|
||||
max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
|
||||
}
|
||||
|
||||
project_entry_id_to_outline_ranges
|
||||
.entry(declaration.declaration.project_entry_id())
|
||||
.or_default()
|
||||
.push(declaration.declaration.item_range());
|
||||
scored_declarations_for_identifier.push(declaration);
|
||||
}
|
||||
|
||||
if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
|
||||
for declaration in scored_declarations_for_identifier.iter_mut() {
|
||||
if max_import_similarity > 0.0 {
|
||||
declaration.components.max_import_similarity = max_import_similarity;
|
||||
declaration.components.normalized_import_similarity =
|
||||
declaration.components.import_similarity / max_import_similarity;
|
||||
}
|
||||
if max_wildcard_import_similarity > 0.0 {
|
||||
declaration.components.normalized_wildcard_import_similarity =
|
||||
declaration.components.wildcard_import_similarity
|
||||
/ max_wildcard_import_similarity;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scored_declarations.extend(scored_declarations_for_identifier);
|
||||
}
|
||||
|
||||
// TODO: Inform this via import / retrieval scores of outline items
|
||||
// TODO: Consider using a sweepline
|
||||
for scored_declaration in scored_declarations.iter_mut() {
|
||||
let project_entry_id = scored_declaration.declaration.project_entry_id();
|
||||
let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
|
||||
continue;
|
||||
};
|
||||
for range in ranges {
|
||||
if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
|
||||
scored_declaration.components.included_by_others += 1
|
||||
} else if scored_declaration
|
||||
.declaration
|
||||
.item_range()
|
||||
.contains_inclusive(range)
|
||||
{
|
||||
scored_declaration.components.includes_others += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
scored_declarations.sort_unstable_by_key(|declaration| {
|
||||
Reverse(OrderedFloat(
|
||||
declaration.score(DeclarationStyle::Declaration),
|
||||
))
|
||||
});
|
||||
|
||||
scored_declarations
|
||||
}
|
||||
|
||||
/// A candidate declaration for an identifier together with the cheap facts gathered
/// before full similarity scoring.
struct CheckedDeclaration<'a> {
    declaration: &'a Declaration,
    /// `Some(row distance from the cursor)` when the declaration lives in the current
    /// buffer; `None` for declarations from other files.
    same_file_line_distance: Option<u32>,
    /// How many exact import paths of the identifier match the declaration's file.
    path_import_match_count: usize,
    /// How many wildcard import paths match the declaration's file.
    wildcard_path_import_match_count: usize,
}
|
||||
|
||||
fn declaration_path_matches_import(
|
||||
declaration_path: &CachedDeclarationPath,
|
||||
import_path: &Arc<Path>,
|
||||
) -> bool {
|
||||
if import_path.is_absolute() {
|
||||
declaration_path.equals_absolute_path(import_path)
|
||||
} else {
|
||||
declaration_path.ends_with_posix_path(import_path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the overlapping sub-range of `a` and `b`, or `None` when they do not
/// overlap. Half-open ranges that merely touch at an endpoint do not intersect.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lo = a.start.clone().max(b.start.clone());
    let hi = a.end.clone().min(b.end.clone());
    (lo < hi).then_some(lo..hi)
}
|
||||
|
||||
fn score_declaration(
|
||||
identifier: &Identifier,
|
||||
references: &[Reference],
|
||||
checked_declaration: CheckedDeclaration,
|
||||
same_file_declaration_count: usize,
|
||||
declaration_count: usize,
|
||||
excerpt_occurrences: &Occurrences,
|
||||
adjacent_occurrences: &Occurrences,
|
||||
import_occurrences: &[Occurrences],
|
||||
wildcard_import_occurrences: &[Occurrences],
|
||||
cursor: Point,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> ScoredDeclaration {
|
||||
let CheckedDeclaration {
|
||||
declaration,
|
||||
same_file_line_distance,
|
||||
path_import_match_count,
|
||||
wildcard_path_import_match_count,
|
||||
} = checked_declaration;
|
||||
|
||||
let is_referenced_nearby = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Nearby);
|
||||
let is_referenced_in_breadcrumb = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Breadcrumb);
|
||||
let reference_count = references.len();
|
||||
let reference_line_distance = references
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let reference_line = r.range.start.to_point(current_buffer).row as i32;
|
||||
(cursor.row as i32 - reference_line).unsigned_abs()
|
||||
})
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
let is_same_file = same_file_line_distance.is_some();
|
||||
let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
|
||||
|
||||
let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
|
||||
let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
|
||||
let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
|
||||
let excerpt_vs_signature_jaccard =
|
||||
jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
|
||||
let adjacent_vs_item_jaccard =
|
||||
jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_jaccard =
|
||||
jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
|
||||
|
||||
let excerpt_vs_item_weighted_overlap =
|
||||
weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
|
||||
let excerpt_vs_signature_weighted_overlap =
|
||||
weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
|
||||
let adjacent_vs_item_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
|
||||
|
||||
let mut import_similarity = 0f32;
|
||||
let mut wildcard_import_similarity = 0f32;
|
||||
if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
|
||||
let cached_path = declaration.cached_path();
|
||||
let path_occurrences = Occurrences::from_worktree_path(
|
||||
cached_path
|
||||
.worktree_abs_path
|
||||
.file_name()
|
||||
.map(|f| f.to_string_lossy()),
|
||||
&cached_path.rel_path,
|
||||
);
|
||||
import_similarity = import_occurrences
|
||||
.iter()
|
||||
.map(|namespace_occurrences| {
|
||||
OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
|
||||
})
|
||||
.max()
|
||||
.map(|similarity| similarity.into_inner())
|
||||
.unwrap_or_default();
|
||||
|
||||
// TODO: Consider something other than max
|
||||
wildcard_import_similarity = wildcard_import_occurrences
|
||||
.iter()
|
||||
.map(|namespace_occurrences| {
|
||||
OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
|
||||
})
|
||||
.max()
|
||||
.map(|similarity| similarity.into_inner())
|
||||
.unwrap_or_default();
|
||||
}
|
||||
|
||||
// TODO: Consider adding declaration_file_count
|
||||
let score_components = DeclarationScoreComponents {
|
||||
is_same_file,
|
||||
is_referenced_nearby,
|
||||
is_referenced_in_breadcrumb,
|
||||
reference_line_distance,
|
||||
declaration_line_distance,
|
||||
reference_count,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
excerpt_vs_item_jaccard,
|
||||
excerpt_vs_signature_jaccard,
|
||||
adjacent_vs_item_jaccard,
|
||||
adjacent_vs_signature_jaccard,
|
||||
excerpt_vs_item_weighted_overlap,
|
||||
excerpt_vs_signature_weighted_overlap,
|
||||
adjacent_vs_item_weighted_overlap,
|
||||
adjacent_vs_signature_weighted_overlap,
|
||||
path_import_match_count,
|
||||
wildcard_path_import_match_count,
|
||||
import_similarity,
|
||||
max_import_similarity: 0.0,
|
||||
normalized_import_similarity: 0.0,
|
||||
wildcard_import_similarity,
|
||||
normalized_wildcard_import_similarity: 0.0,
|
||||
included_by_others: 0,
|
||||
includes_others: 0,
|
||||
};
|
||||
|
||||
ScoredDeclaration {
|
||||
identifier: identifier.clone(),
|
||||
declaration: declaration.clone(),
|
||||
components: score_components,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");

        let matches =
            |import: &str| declaration_path_matches_import(&declaration_path, &Path::new(import).into());

        // Relative imports match any POSIX suffix of the declaration's path, and an
        // absolute import matches the full absolute path.
        assert!(matches("maths.ts"));
        assert!(matches("project/src/maths.ts"));
        assert!(matches("user/project/src/maths.ts"));
        assert!(matches("/home/user/project/src/maths.ts"));

        // A different file name never matches, relative or absolute.
        assert!(!matches("other.ts"));
        assert!(!matches("/home/user/project/src/other.ts"));
    }
}
|
||||
@@ -1,335 +1,490 @@
|
||||
mod declaration;
|
||||
mod declaration_scoring;
|
||||
mod excerpt;
|
||||
mod imports;
|
||||
mod outline;
|
||||
mod reference;
|
||||
mod syntax_index;
|
||||
pub mod text_similarity;
|
||||
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use crate::assemble_excerpts::assemble_excerpts;
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use language::BufferSnapshot;
|
||||
use text::{Point, ToOffset as _};
|
||||
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
|
||||
use project::{LocationLink, Project, ProjectPath};
|
||||
use serde::{Serialize, Serializer};
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
ops::Range,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::{RangeExt as _, ResultExt};
|
||||
|
||||
pub use declaration::*;
|
||||
pub use declaration_scoring::*;
|
||||
pub use excerpt::*;
|
||||
pub use imports::*;
|
||||
pub use reference::*;
|
||||
pub use syntax_index::*;
|
||||
mod assemble_excerpts;
|
||||
#[cfg(test)]
|
||||
mod edit_prediction_context_tests;
|
||||
mod excerpt;
|
||||
#[cfg(test)]
|
||||
mod fake_definition_lsp;
|
||||
|
||||
pub use predict_edits_v3::Line;
|
||||
pub use cloud_llm_client::predict_edits_v3::Line;
|
||||
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct EditPredictionContextOptions {
|
||||
pub use_imports: bool,
|
||||
pub excerpt: EditPredictionExcerptOptions,
|
||||
pub score: EditPredictionScoreOptions,
|
||||
pub max_retrieved_declarations: u8,
|
||||
/// Default number of lines around the cursor scanned for identifiers
/// (see `identifiers_for_position`). NOTE(review): exact window semantics live in
/// that helper — confirm whether this is lines before, after, or both.
const IDENTIFIER_LINE_COUNT: u32 = 3;
|
||||
|
||||
/// Maintains the set of files/excerpts related to the current cursor position by
/// resolving definitions for identifiers near the cursor, with debounced refreshes.
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    /// Latest computed set of related files, replaced wholesale on each refresh.
    related_files: Vec<RelatedFile>,
    /// Definition results keyed by identifier occurrence, reused across refreshes.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Feeds (buffer, cursor position) updates to the debounced background refresh task.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of lines around the cursor scanned for identifiers; defaults to
    /// `IDENTIFIER_LINE_COUNT`, overridable via `set_identifier_line_count`.
    identifier_line_count: u32,
}
|
||||
|
||||
/// Events emitted around each refresh cycle of a `RelatedExcerptStore`.
pub enum RelatedExcerptStoreEvent {
    /// A refresh has begun for the most recent cursor position.
    StartedRefresh,
    /// A refresh completed and `related_files` has been updated.
    FinishedRefresh {
        /// Identifiers whose definitions were served from the existing cache.
        cache_hit_count: usize,
        /// Identifiers that required a fresh definition request.
        cache_miss_count: usize,
        /// Mean latency of the definition requests (zero when there were no misses).
        mean_definition_latency: Duration,
        /// Slowest single definition request this refresh.
        max_definition_latency: Duration,
    },
}
|
||||
|
||||
/// An identifier occurrence near the cursor; used as the cache key for definition
/// lookups (both the name and its exact buffer range participate in equality).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    /// Anchor range of this occurrence within the buffer it was found in.
    pub range: Range<Anchor>,
}
|
||||
|
||||
/// Outcome of probing the cache for an identifier: either a previously-resolved
/// entry, or an in-flight go-to-definition request to the project.
enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
|
||||
|
||||
/// Definitions resolved for a single identifier occurrence.
#[derive(Debug)]
struct CacheEntry {
    /// Usually one definition; `SmallVec<[_; 1]>` keeps that common case inline.
    definitions: SmallVec<[CachedDefinition; 1]>,
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EditPredictionContext {
|
||||
pub excerpt: EditPredictionExcerpt,
|
||||
pub excerpt_text: EditPredictionExcerptText,
|
||||
pub cursor_point: Point,
|
||||
pub declarations: Vec<ScoredDeclaration>,
|
||||
/// One resolved definition location retained in the cache.
struct CachedDefinition {
    /// Project path of the buffer containing the definition.
    path: ProjectPath,
    /// Buffer the definition lives in; held strongly so its snapshot stays available.
    buffer: Entity<Buffer>,
    /// Anchor range of the definition within `buffer`.
    anchor_range: Range<Anchor>,
}
|
||||
|
||||
impl EditPredictionContext {
|
||||
pub fn gather_context_in_background(
|
||||
cursor_point: Point,
|
||||
buffer: BufferSnapshot,
|
||||
options: EditPredictionContextOptions,
|
||||
syntax_index: Option<Entity<SyntaxIndex>>,
|
||||
cx: &mut App,
|
||||
) -> Task<Option<Self>> {
|
||||
let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
|
||||
let mut path = f.worktree.read(cx).absolutize(&f.path);
|
||||
if path.pop() { Some(path) } else { None }
|
||||
/// A file containing definitions related to the cursor, plus the excerpts of that
/// file selected for edit-prediction context.
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
    /// Serialized as just the worktree-relative path.
    #[serde(serialize_with = "serialize_project_path")]
    pub path: ProjectPath,
    #[serde(skip)]
    pub buffer: WeakEntity<Buffer>,
    pub excerpts: Vec<RelatedExcerpt>,
    /// Row of the buffer's last line (taken from `snapshot.max_point().row` when the
    /// file is assembled).
    pub max_row: u32,
}
|
||||
|
||||
impl RelatedFile {
|
||||
pub fn merge_excerpts(&mut self) {
|
||||
self.excerpts.sort_unstable_by(|a, b| {
|
||||
a.point_range
|
||||
.start
|
||||
.cmp(&b.point_range.start)
|
||||
.then(b.point_range.end.cmp(&a.point_range.end))
|
||||
});
|
||||
|
||||
if let Some(syntax_index) = syntax_index {
|
||||
let index_state =
|
||||
syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
|
||||
cx.background_spawn(async move {
|
||||
let parent_abs_path = parent_abs_path.as_deref();
|
||||
let index_state = index_state.upgrade()?;
|
||||
let index_state = index_state.lock().await;
|
||||
Self::gather_context(
|
||||
cursor_point,
|
||||
&buffer,
|
||||
parent_abs_path,
|
||||
&options,
|
||||
Some(&index_state),
|
||||
)
|
||||
})
|
||||
} else {
|
||||
cx.background_spawn(async move {
|
||||
let parent_abs_path = parent_abs_path.as_deref();
|
||||
Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
|
||||
})
|
||||
let mut index = 1;
|
||||
while index < self.excerpts.len() {
|
||||
if self.excerpts[index - 1]
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index].point_range.start)
|
||||
.is_ge()
|
||||
{
|
||||
let removed = self.excerpts.remove(index);
|
||||
if removed
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index - 1].point_range.end)
|
||||
.is_gt()
|
||||
{
|
||||
self.excerpts[index - 1].point_range.end = removed.point_range.end;
|
||||
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
|
||||
}
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A contiguous excerpt of a related file.
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
    /// Buffer-anchored form of the range; not serialized.
    #[serde(skip)]
    pub anchor_range: Range<Anchor>,
    /// The same range resolved to row/column coordinates; serialized as
    /// `[[start_row, start_col], [end_row, end_col]]`.
    #[serde(serialize_with = "serialize_point_range")]
    pub point_range: Range<Point>,
    /// The excerpt's text; serialized as a plain string.
    #[serde(serialize_with = "serialize_rope")]
    pub text: Rope,
}
|
||||
|
||||
/// Serde helper: serializes a `ProjectPath` as just its worktree-relative path,
/// dropping the worktree id.
fn serialize_project_path<S: Serializer>(
    project_path: &ProjectPath,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    project_path.path.serialize(serializer)
}
|
||||
|
||||
/// Serde helper: serializes a `Rope` as one owned `String` (materializes the whole
/// rope; excerpts are small so this is acceptable).
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
    rope.to_string().serialize(serializer)
}
|
||||
|
||||
/// Serde helper: serializes a point range as `[[start_row, start_col], [end_row, end_col]]`.
fn serialize_point_range<S: Serializer>(
    range: &Range<Point>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    [
        [range.start.row, range.start.column],
        [range.end.row, range.end.column],
    ]
    .serialize(serializer)
}
|
||||
|
||||
/// How long to wait after the last cursor update before starting a refresh; newer
/// updates within this window reset the timer.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
|
||||
|
||||
impl RelatedExcerptStore {
|
||||
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
|
||||
let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let executor = cx.background_executor().clone();
|
||||
while let Some((mut buffer, mut position)) = update_rx.next().await {
|
||||
let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
|
||||
loop {
|
||||
futures::select_biased! {
|
||||
next = update_rx.next() => {
|
||||
if let Some((new_buffer, new_position)) = next {
|
||||
buffer = new_buffer;
|
||||
position = new_position;
|
||||
timer = executor.timer(DEBOUNCE_DURATION).fuse();
|
||||
} else {
|
||||
return anyhow::Ok(());
|
||||
}
|
||||
}
|
||||
_ = timer => break,
|
||||
}
|
||||
}
|
||||
|
||||
Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
RelatedExcerptStore {
|
||||
project: project.downgrade(),
|
||||
update_tx,
|
||||
related_files: Vec::new(),
|
||||
cache: Default::default(),
|
||||
identifier_line_count: IDENTIFIER_LINE_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gather_context(
|
||||
cursor_point: Point,
|
||||
buffer: &BufferSnapshot,
|
||||
parent_abs_path: Option<&Path>,
|
||||
options: &EditPredictionContextOptions,
|
||||
index_state: Option<&SyntaxIndexState>,
|
||||
) -> Option<Self> {
|
||||
let imports = if options.use_imports {
|
||||
Imports::gather(&buffer, parent_abs_path)
|
||||
} else {
|
||||
Imports::default()
|
||||
};
|
||||
Self::gather_context_with_references_fn(
|
||||
cursor_point,
|
||||
buffer,
|
||||
&imports,
|
||||
options,
|
||||
index_state,
|
||||
references_in_excerpt,
|
||||
)
|
||||
/// Overrides how many lines around the cursor are scanned for identifiers on the
/// next refresh (defaults to `IDENTIFIER_LINE_COUNT`).
pub fn set_identifier_line_count(&mut self, count: u32) {
    self.identifier_line_count = count;
}
|
||||
|
||||
pub fn gather_context_with_references_fn(
|
||||
cursor_point: Point,
|
||||
buffer: &BufferSnapshot,
|
||||
imports: &Imports,
|
||||
options: &EditPredictionContextOptions,
|
||||
index_state: Option<&SyntaxIndexState>,
|
||||
get_references: impl FnOnce(
|
||||
&EditPredictionExcerpt,
|
||||
&EditPredictionExcerptText,
|
||||
&BufferSnapshot,
|
||||
) -> HashMap<Identifier, Vec<Reference>>,
|
||||
) -> Option<Self> {
|
||||
let excerpt = EditPredictionExcerpt::select_from_buffer(
|
||||
cursor_point,
|
||||
buffer,
|
||||
&options.excerpt,
|
||||
index_state,
|
||||
)?;
|
||||
let excerpt_text = excerpt.text(buffer);
|
||||
/// Requests a (debounced) refresh of related excerpts for the given cursor position.
/// Send failures (e.g. after the background task has exited) are deliberately ignored.
pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
    self.update_tx.unbounded_send((buffer, position)).ok();
}
|
||||
|
||||
let declarations = if options.max_retrieved_declarations > 0
|
||||
&& let Some(index_state) = index_state
|
||||
{
|
||||
let excerpt_occurrences =
|
||||
text_similarity::Occurrences::within_string(&excerpt_text.body);
|
||||
/// The most recently computed related files, sorted by project path.
pub fn related_files(&self) -> &[RelatedFile] {
    &self.related_files
}
|
||||
|
||||
let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
|
||||
let adjacent_end = Point::new(cursor_point.row + 1, 0);
|
||||
let adjacent_occurrences = text_similarity::Occurrences::within_string(
|
||||
&buffer
|
||||
.text_for_range(adjacent_start..adjacent_end)
|
||||
.collect::<String>(),
|
||||
);
|
||||
|
||||
let cursor_offset_in_file = cursor_point.to_offset(buffer);
|
||||
|
||||
let references = get_references(&excerpt, &excerpt_text, buffer);
|
||||
|
||||
let mut declarations = scored_declarations(
|
||||
&options.score,
|
||||
&index_state,
|
||||
&excerpt,
|
||||
&excerpt_occurrences,
|
||||
&adjacent_occurrences,
|
||||
&imports,
|
||||
references,
|
||||
cursor_offset_in_file,
|
||||
buffer,
|
||||
);
|
||||
// TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
|
||||
declarations.truncate(options.max_retrieved_declarations as usize);
|
||||
declarations
|
||||
} else {
|
||||
vec![]
|
||||
async fn fetch_excerpts(
|
||||
this: WeakEntity<Self>,
|
||||
buffer: Entity<Buffer>,
|
||||
position: Anchor,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
|
||||
(
|
||||
this.project.upgrade(),
|
||||
buffer.read(cx).snapshot(),
|
||||
this.identifier_line_count,
|
||||
)
|
||||
})?;
|
||||
let Some(project) = project else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
Some(Self {
|
||||
excerpt,
|
||||
excerpt_text,
|
||||
cursor_point,
|
||||
declarations,
|
||||
})
|
||||
let file = snapshot.file().cloned();
|
||||
if let Some(file) = &file {
|
||||
log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
|
||||
}
|
||||
|
||||
this.update(cx, |_, cx| {
|
||||
cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
|
||||
})?;
|
||||
|
||||
let identifiers = cx
|
||||
.background_spawn(async move {
|
||||
identifiers_for_position(&snapshot, position, identifier_line_count)
|
||||
})
|
||||
.await;
|
||||
|
||||
let async_cx = cx.clone();
|
||||
let start_time = Instant::now();
|
||||
let futures = this.update(cx, |this, cx| {
|
||||
identifiers
|
||||
.into_iter()
|
||||
.filter_map(|identifier| {
|
||||
let task = if let Some(entry) = this.cache.get(&identifier) {
|
||||
DefinitionTask::CacheHit(entry.clone())
|
||||
} else {
|
||||
DefinitionTask::CacheMiss(
|
||||
this.project
|
||||
.update(cx, |project, cx| {
|
||||
project.definitions(&buffer, identifier.range.start, cx)
|
||||
})
|
||||
.ok()?,
|
||||
)
|
||||
};
|
||||
|
||||
let cx = async_cx.clone();
|
||||
let project = project.clone();
|
||||
Some(async move {
|
||||
match task {
|
||||
DefinitionTask::CacheHit(cache_entry) => {
|
||||
Some((identifier, cache_entry, None))
|
||||
}
|
||||
DefinitionTask::CacheMiss(task) => {
|
||||
let locations = task.await.log_err()??;
|
||||
let duration = start_time.elapsed();
|
||||
cx.update(|cx| {
|
||||
(
|
||||
identifier,
|
||||
Arc::new(CacheEntry {
|
||||
definitions: locations
|
||||
.into_iter()
|
||||
.filter_map(|location| {
|
||||
process_definition(location, &project, cx)
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
Some(duration),
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})?;
|
||||
|
||||
let mut cache_hit_count = 0;
|
||||
let mut cache_miss_count = 0;
|
||||
let mut mean_definition_latency = Duration::ZERO;
|
||||
let mut max_definition_latency = Duration::ZERO;
|
||||
let mut new_cache = HashMap::default();
|
||||
new_cache.reserve(futures.len());
|
||||
for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
|
||||
new_cache.insert(identifier, entry);
|
||||
if let Some(duration) = duration {
|
||||
cache_miss_count += 1;
|
||||
mean_definition_latency += duration;
|
||||
max_definition_latency = max_definition_latency.max(duration);
|
||||
} else {
|
||||
cache_hit_count += 1;
|
||||
}
|
||||
}
|
||||
mean_definition_latency /= cache_miss_count.max(1) as u32;
|
||||
|
||||
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
|
||||
|
||||
if let Some(file) = &file {
|
||||
log::debug!(
|
||||
"finished retrieving context buffer:{}, latency:{:?}",
|
||||
file.path().as_unix_str(),
|
||||
start_time.elapsed()
|
||||
);
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.cache = new_cache;
|
||||
this.related_files = related_files;
|
||||
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
});
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
use crate::{EditPredictionExcerptOptions, SyntaxIndex};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_call_site(cx: &mut TestAppContext) {
|
||||
let (project, index, _rust_lang_id) = init_test(cx).await;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let project_path = project.find_project_path("c.rs", cx).unwrap();
|
||||
project.open_buffer(project_path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// first process_data call site
|
||||
let cursor_point = language::Point::new(8, 21);
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let context = cx
|
||||
.update(|cx| {
|
||||
EditPredictionContext::gather_context_in_background(
|
||||
cursor_point,
|
||||
buffer_snapshot,
|
||||
EditPredictionContextOptions {
|
||||
use_imports: true,
|
||||
excerpt: EditPredictionExcerptOptions {
|
||||
max_bytes: 60,
|
||||
min_bytes: 10,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
score: EditPredictionScoreOptions {
|
||||
omit_excerpt_overlaps: true,
|
||||
},
|
||||
max_retrieved_declarations: u8::MAX,
|
||||
},
|
||||
Some(index.clone()),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut snippet_identifiers = context
|
||||
.declarations
|
||||
.iter()
|
||||
.map(|snippet| snippet.identifier.name.as_ref())
|
||||
.collect::<Vec<_>>();
|
||||
snippet_identifiers.sort();
|
||||
assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
|
||||
drop(buffer);
|
||||
async fn rebuild_related_files(
|
||||
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
|
||||
let mut snapshots = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
for definition in &entry.definitions {
|
||||
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
|
||||
definition
|
||||
.buffer
|
||||
.read_with(cx, |buffer, _| buffer.parsing_idle())?
|
||||
.await;
|
||||
e.insert(
|
||||
definition
|
||||
.buffer
|
||||
.read_with(cx, |buffer, _| buffer.snapshot())?,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn init_test(
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Ok(cx
|
||||
.background_spawn(async move {
|
||||
let mut files = Vec::<RelatedFile>::new();
|
||||
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
|
||||
let mut paths_by_buffer = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
for definition in &entry.definitions {
|
||||
let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
|
||||
continue;
|
||||
};
|
||||
paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
|
||||
ranges_by_buffer
|
||||
.entry(definition.buffer.clone())
|
||||
.or_default()
|
||||
.push(definition.anchor_range.to_point(snapshot));
|
||||
}
|
||||
}
|
||||
|
||||
for (buffer, ranges) in ranges_by_buffer {
|
||||
let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
|
||||
continue;
|
||||
};
|
||||
let excerpts = assemble_excerpts(snapshot, ranges);
|
||||
files.push(RelatedFile {
|
||||
path: project_path.clone(),
|
||||
buffer: buffer.downgrade(),
|
||||
excerpts,
|
||||
max_row: snapshot.max_point().row,
|
||||
});
|
||||
}
|
||||
|
||||
files.sort_by_key(|file| file.path.clone());
|
||||
(new_entries, files)
|
||||
})
|
||||
.await)
|
||||
}
|
||||
|
||||
fn process_definition(
|
||||
location: LocationLink,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<CachedDefinition> {
|
||||
let buffer = location.target.buffer.read(cx);
|
||||
let anchor_range = location.target.range;
|
||||
let file = buffer.file()?;
|
||||
let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
|
||||
if worktree.read(cx).is_single_file() {
|
||||
return None;
|
||||
}
|
||||
Some(CachedDefinition {
|
||||
path: ProjectPath {
|
||||
worktree_id: file.worktree_id(cx),
|
||||
path: file.path().clone(),
|
||||
},
|
||||
buffer: location.target.buffer,
|
||||
anchor_range,
|
||||
})
|
||||
}
|
||||
|
||||
/// Gets all of the identifiers that are present in the given line, and its containing
|
||||
/// outline items.
|
||||
fn identifiers_for_position(
|
||||
buffer: &BufferSnapshot,
|
||||
position: Anchor,
|
||||
identifier_line_count: u32,
|
||||
) -> Vec<Identifier> {
|
||||
let offset = position.to_offset(buffer);
|
||||
let point = buffer.offset_to_point(offset);
|
||||
|
||||
// Search for identifiers on lines adjacent to the cursor.
|
||||
let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
|
||||
let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
|
||||
let line_range = start..end;
|
||||
let mut ranges = vec![line_range.to_offset(&buffer)];
|
||||
|
||||
// Search for identifiers mentioned in headers/signatures of containing outline items.
|
||||
let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
|
||||
for item in outline_items {
|
||||
if let Some(body_range) = item.body_range(&buffer) {
|
||||
ranges.push(item.range.start..body_range.start.to_offset(&buffer));
|
||||
} else {
|
||||
ranges.push(item.range.clone());
|
||||
}
|
||||
}
|
||||
|
||||
ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
|
||||
ranges.dedup_by(|a, b| {
|
||||
if a.start <= b.end {
|
||||
b.start = b.start.min(a.start);
|
||||
b.end = b.end.max(a.end);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
let mut identifiers = Vec::new();
|
||||
let outer_range =
|
||||
ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
|
||||
|
||||
let mut captures = buffer
|
||||
.syntax
|
||||
.captures(outer_range.clone(), &buffer.text, |grammar| {
|
||||
grammar
|
||||
.highlights_config
|
||||
.as_ref()
|
||||
.map(|config| &config.query)
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"a.rs": indoc! {r#"
|
||||
fn main() {
|
||||
let x = 1;
|
||||
let y = 2;
|
||||
let z = add(x, y);
|
||||
println!("Result: {}", z);
|
||||
}
|
||||
for range in ranges {
|
||||
captures.set_byte_range(range.start..outer_range.end);
|
||||
|
||||
fn add(a: i32, b: i32) -> i32 {
|
||||
a + b
|
||||
}
|
||||
"#},
|
||||
"b.rs": indoc! {"
|
||||
pub struct Config {
|
||||
pub name: String,
|
||||
pub value: i32,
|
||||
}
|
||||
let mut last_range = None;
|
||||
while let Some(capture) = captures.peek() {
|
||||
let node_range = capture.node.byte_range();
|
||||
if node_range.start > range.end {
|
||||
break;
|
||||
}
|
||||
let config = captures.grammars()[capture.grammar_index]
|
||||
.highlights_config
|
||||
.as_ref();
|
||||
|
||||
impl Config {
|
||||
pub fn new(name: String, value: i32) -> Self {
|
||||
Config { name, value }
|
||||
}
|
||||
}
|
||||
"},
|
||||
"c.rs": indoc! {r#"
|
||||
use std::collections::HashMap;
|
||||
if let Some(config) = config
|
||||
&& config.identifier_capture_indices.contains(&capture.index)
|
||||
&& range.contains_inclusive(&node_range)
|
||||
&& Some(&node_range) != last_range.as_ref()
|
||||
{
|
||||
let name = buffer.text_for_range(node_range.clone()).collect();
|
||||
identifiers.push(Identifier {
|
||||
range: buffer.anchor_after(node_range.start)
|
||||
..buffer.anchor_before(node_range.end),
|
||||
name,
|
||||
});
|
||||
last_range = Some(node_range);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let data: Vec<i32> = args[1..]
|
||||
.iter()
|
||||
.filter_map(|s| s.parse().ok())
|
||||
.collect();
|
||||
let result = process_data(data);
|
||||
println!("{:?}", result);
|
||||
}
|
||||
|
||||
fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
|
||||
let mut counts = HashMap::new();
|
||||
for value in data {
|
||||
*counts.entry(value).or_insert(0) += 1;
|
||||
}
|
||||
counts
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_process_data() {
|
||||
let data = vec![1, 2, 2, 3];
|
||||
let result = process_data(data);
|
||||
assert_eq!(result.get(&2), Some(&2));
|
||||
}
|
||||
}
|
||||
"#}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
let lang = rust_lang();
|
||||
let lang_id = lang.id();
|
||||
language_registry.add(Arc::new(lang));
|
||||
|
||||
let file_indexing_parallelism = 2;
|
||||
let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
(project, index, lang_id)
|
||||
captures.advance();
|
||||
}
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
|
||||
.unwrap()
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
identifiers
|
||||
}
|
||||
|
||||
@@ -0,0 +1,530 @@
|
||||
use super::*;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::{Language, LanguageConfig, LanguageMatcher, Point, ToPoint as _, tree_sitter_rust};
|
||||
use lsp::FakeLanguageServer;
|
||||
use project::{FakeFs, LocationLink, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use std::{fmt::Write as _, sync::Arc};
|
||||
use util::{path, test::marked_text_ranges};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/root"), test_project_1()).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let mut servers = setup_fake_lsp(&project, cx);
|
||||
|
||||
let (buffer, _handle) = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _server = servers.next().await.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
|
||||
related_excerpt_store.update(cx, |store, cx| {
|
||||
let position = {
|
||||
let buffer = buffer.read(cx);
|
||||
let offset = buffer.text().find("todo").unwrap();
|
||||
buffer.anchor_before(offset)
|
||||
};
|
||||
|
||||
store.set_identifier_line_count(0);
|
||||
store.refresh(buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
cx.executor().advance_clock(DEBOUNCE_DURATION);
|
||||
related_excerpt_store.update(cx, |store, _| {
|
||||
let excerpts = store.related_files();
|
||||
assert_related_files(
|
||||
&excerpts,
|
||||
&[
|
||||
(
|
||||
"src/company.rs",
|
||||
&[indoc! {"
|
||||
pub struct Company {
|
||||
owner: Arc<Person>,
|
||||
address: Address,
|
||||
}"}],
|
||||
),
|
||||
(
|
||||
"src/main.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
pub struct Session {
|
||||
company: Arc<Company>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn set_company(&mut self, company: Arc<Company>) {"},
|
||||
indoc! {"
|
||||
}
|
||||
}"},
|
||||
],
|
||||
),
|
||||
(
|
||||
"src/person.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
impl Person {
|
||||
pub fn get_first_name(&self) -> &str {
|
||||
&self.first_name
|
||||
}"},
|
||||
"}",
|
||||
],
|
||||
),
|
||||
],
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_assemble_excerpts(cx: &mut TestAppContext) {
|
||||
let table = [
|
||||
(
|
||||
indoc! {r#"
|
||||
struct User {
|
||||
first_name: String,
|
||||
«last_name»: String,
|
||||
age: u32,
|
||||
email: String,
|
||||
create_at: Instant,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn first_name(&self) -> String {
|
||||
self.first_name.clone()
|
||||
}
|
||||
|
||||
pub fn full_name(&self) -> String {
|
||||
« format!("{} {}", self.first_name, self.last_name)
|
||||
» }
|
||||
}
|
||||
"#},
|
||||
indoc! {r#"
|
||||
struct User {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
…
|
||||
}
|
||||
|
||||
impl User {
|
||||
…
|
||||
pub fn full_name(&self) -> String {
|
||||
format!("{} {}", self.first_name, self.last_name)
|
||||
}
|
||||
}
|
||||
"#},
|
||||
),
|
||||
(
|
||||
indoc! {r#"
|
||||
struct «User» {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
age: u32,
|
||||
}
|
||||
|
||||
impl User {
|
||||
// methods
|
||||
}
|
||||
"#},
|
||||
indoc! {r#"
|
||||
struct User {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
age: u32,
|
||||
}
|
||||
…
|
||||
"#},
|
||||
),
|
||||
(
|
||||
indoc! {r#"
|
||||
trait «FooProvider» {
|
||||
const NAME: &'static str;
|
||||
|
||||
fn provide_foo(&self, id: usize) -> Foo;
|
||||
|
||||
fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
ids.iter()
|
||||
.map(|id| self.provide_foo(*id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#
|
||||
},
|
||||
indoc! {r#"
|
||||
trait FooProvider {
|
||||
const NAME: &'static str;
|
||||
|
||||
fn provide_foo(&self, id: usize) -> Foo;
|
||||
|
||||
fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
…
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#},
|
||||
),
|
||||
(
|
||||
indoc! {r#"
|
||||
trait «Something» {
|
||||
fn method1(&self, id: usize) -> Foo;
|
||||
|
||||
fn method2(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
struct Helper1 {
|
||||
field1: usize,
|
||||
}
|
||||
|
||||
struct Helper2 {
|
||||
field2: usize,
|
||||
}
|
||||
|
||||
struct Helper3 {
|
||||
filed2: usize,
|
||||
}
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#
|
||||
},
|
||||
indoc! {r#"
|
||||
trait Something {
|
||||
fn method1(&self, id: usize) -> Foo;
|
||||
|
||||
fn method2(&self, ids: &[usize]) -> Vec<Foo> {
|
||||
…
|
||||
}
|
||||
|
||||
fn sync(&self);
|
||||
}
|
||||
"#},
|
||||
),
|
||||
];
|
||||
|
||||
for (input, expected_output) in table {
|
||||
let (input, ranges) = marked_text_ranges(&input, false);
|
||||
let buffer = cx.new(|cx| Buffer::local(input, cx).with_language(rust_lang(), cx));
|
||||
buffer.read_with(cx, |buffer, _cx| {
|
||||
let ranges: Vec<Range<Point>> = ranges
|
||||
.into_iter()
|
||||
.map(|range| range.to_point(&buffer))
|
||||
.collect();
|
||||
|
||||
let excerpts = assemble_excerpts(&buffer.snapshot(), ranges);
|
||||
|
||||
let output = format_excerpts(buffer, &excerpts);
|
||||
assert_eq!(output, expected_output);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fake_definition_lsp(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/root"), test_project_1()).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let mut servers = setup_fake_lsp(&project, cx);
|
||||
|
||||
let (buffer, _handle) = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let _server = servers.next().await.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
|
||||
|
||||
let definitions = project
|
||||
.update(cx, |project, cx| {
|
||||
let offset = buffer_text.find("Address {").unwrap();
|
||||
project.definitions(&buffer, offset, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_definitions(&definitions, &["pub struct Address {"], cx);
|
||||
|
||||
let definitions = project
|
||||
.update(cx, |project, cx| {
|
||||
let offset = buffer_text.find("State::CA").unwrap();
|
||||
project.definitions(&buffer, offset, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_definitions(&definitions, &["pub enum State {"], cx);
|
||||
|
||||
let definitions = project
|
||||
.update(cx, |project, cx| {
|
||||
let offset = buffer_text.find("to_string()").unwrap();
|
||||
project.definitions(&buffer, offset, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(|cx| SettingsStore::test(cx));
|
||||
cx.set_global(settings_store);
|
||||
env_logger::try_init().ok();
|
||||
}
|
||||
|
||||
fn setup_fake_lsp(
|
||||
project: &Entity<Project>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> UnboundedReceiver<FakeLanguageServer> {
|
||||
let (language_registry, fs) = project.read_with(cx, |project, _| {
|
||||
(project.languages().clone(), project.fs().clone())
|
||||
});
|
||||
let language = rust_lang();
|
||||
language_registry.add(language.clone());
|
||||
fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs)
|
||||
}
|
||||
|
||||
fn test_project_1() -> serde_json::Value {
|
||||
let person_rs = indoc! {r#"
|
||||
pub struct Person {
|
||||
first_name: String,
|
||||
last_name: String,
|
||||
email: String,
|
||||
age: u32,
|
||||
}
|
||||
|
||||
impl Person {
|
||||
pub fn get_first_name(&self) -> &str {
|
||||
&self.first_name
|
||||
}
|
||||
|
||||
pub fn get_last_name(&self) -> &str {
|
||||
&self.last_name
|
||||
}
|
||||
|
||||
pub fn get_email(&self) -> &str {
|
||||
&self.email
|
||||
}
|
||||
|
||||
pub fn get_age(&self) -> u32 {
|
||||
self.age
|
||||
}
|
||||
}
|
||||
"#};
|
||||
|
||||
let address_rs = indoc! {r#"
|
||||
pub struct Address {
|
||||
street: String,
|
||||
city: String,
|
||||
state: State,
|
||||
zip: u32,
|
||||
}
|
||||
|
||||
pub enum State {
|
||||
CA,
|
||||
OR,
|
||||
WA,
|
||||
TX,
|
||||
// ...
|
||||
}
|
||||
|
||||
impl Address {
|
||||
pub fn get_street(&self) -> &str {
|
||||
&self.street
|
||||
}
|
||||
|
||||
pub fn get_city(&self) -> &str {
|
||||
&self.city
|
||||
}
|
||||
|
||||
pub fn get_state(&self) -> State {
|
||||
self.state
|
||||
}
|
||||
|
||||
pub fn get_zip(&self) -> u32 {
|
||||
self.zip
|
||||
}
|
||||
}
|
||||
"#};
|
||||
|
||||
let company_rs = indoc! {r#"
|
||||
use super::person::Person;
|
||||
use super::address::Address;
|
||||
|
||||
pub struct Company {
|
||||
owner: Arc<Person>,
|
||||
address: Address,
|
||||
}
|
||||
|
||||
impl Company {
|
||||
pub fn get_owner(&self) -> &Person {
|
||||
&self.owner
|
||||
}
|
||||
|
||||
pub fn get_address(&self) -> &Address {
|
||||
&self.address
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
format!("{} ({})", self.owner.first_name, self.address.city)
|
||||
}
|
||||
}
|
||||
"#};
|
||||
|
||||
let main_rs = indoc! {r#"
|
||||
use std::sync::Arc;
|
||||
use super::person::Person;
|
||||
use super::address::Address;
|
||||
use super::company::Company;
|
||||
|
||||
pub struct Session {
|
||||
company: Arc<Company>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn set_company(&mut self, company: Arc<Company>) {
|
||||
self.company = company;
|
||||
if company.owner != self.company.owner {
|
||||
log("new owner", company.owner.get_first_name()); todo();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let company = Company {
|
||||
owner: Arc::new(Person {
|
||||
first_name: "John".to_string(),
|
||||
last_name: "Doe".to_string(),
|
||||
email: "john@example.com".to_string(),
|
||||
age: 30,
|
||||
}),
|
||||
address: Address {
|
||||
street: "123 Main St".to_string(),
|
||||
city: "Anytown".to_string(),
|
||||
state: State::CA,
|
||||
zip: 12345,
|
||||
},
|
||||
};
|
||||
|
||||
println!("Company: {}", company.to_string());
|
||||
}
|
||||
"#};
|
||||
|
||||
json!({
|
||||
"src": {
|
||||
"person.rs": person_rs,
|
||||
"address.rs": address_rs,
|
||||
"company.rs": company_rs,
|
||||
"main.rs": main_rs,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) {
|
||||
let actual_files = actual_files
|
||||
.iter()
|
||||
.map(|file| {
|
||||
let excerpts = file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.text.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
(file.path.path.as_unix_str(), excerpts)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let expected_excerpts = expected_files
|
||||
.iter()
|
||||
.map(|(path, texts)| {
|
||||
(
|
||||
*path,
|
||||
texts
|
||||
.iter()
|
||||
.map(|line| line.to_string())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
pretty_assertions::assert_eq!(actual_files, expected_excerpts)
|
||||
}
|
||||
|
||||
fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) {
|
||||
let actual_first_lines = definitions
|
||||
.iter()
|
||||
.map(|definition| {
|
||||
definition.target.buffer.read_with(cx, |buffer, _| {
|
||||
let mut start = definition.target.range.start.to_point(&buffer);
|
||||
start.column = 0;
|
||||
let end = Point::new(start.row, buffer.line_len(start.row));
|
||||
buffer
|
||||
.text_for_range(start..end)
|
||||
.collect::<String>()
|
||||
.trim()
|
||||
.to_string()
|
||||
})
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
assert_eq!(actual_first_lines, first_lines);
|
||||
}
|
||||
|
||||
fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
|
||||
let mut output = String::new();
|
||||
let file_line_count = buffer.max_point().row;
|
||||
let mut current_row = 0;
|
||||
for excerpt in excerpts {
|
||||
if excerpt.text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if current_row < excerpt.point_range.start.row {
|
||||
writeln!(&mut output, "…").unwrap();
|
||||
}
|
||||
current_row = excerpt.point_range.start.row;
|
||||
|
||||
for line in excerpt.text.to_string().lines() {
|
||||
output.push_str(line);
|
||||
output.push('\n');
|
||||
current_row += 1;
|
||||
}
|
||||
}
|
||||
if current_row < file_line_count {
|
||||
writeln!(&mut output, "…").unwrap();
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn rust_lang() -> Arc<Language> {
|
||||
Arc::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
first_line_pattern: None,
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
|
||||
.unwrap()
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
@@ -1,11 +1,9 @@
|
||||
use language::{BufferSnapshot, LanguageId};
|
||||
use cloud_llm_client::predict_edits_v3::Line;
|
||||
use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _};
|
||||
use std::ops::Range;
|
||||
use text::{Point, ToOffset as _, ToPoint as _};
|
||||
use tree_sitter::{Node, TreeCursor};
|
||||
use util::RangeExt;
|
||||
|
||||
use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// - Test parent signatures
|
||||
@@ -31,19 +29,16 @@ pub struct EditPredictionExcerptOptions {
|
||||
pub target_before_cursor_over_total_bytes: f32,
|
||||
}
|
||||
|
||||
// TODO: consider merging these
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EditPredictionExcerpt {
|
||||
pub range: Range<usize>,
|
||||
pub line_range: Range<Line>,
|
||||
pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
|
||||
pub size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EditPredictionExcerptText {
|
||||
pub body: String,
|
||||
pub parent_signatures: Vec<String>,
|
||||
pub language_id: Option<LanguageId>,
|
||||
}
|
||||
|
||||
@@ -52,17 +47,8 @@ impl EditPredictionExcerpt {
|
||||
let body = buffer
|
||||
.text_for_range(self.range.clone())
|
||||
.collect::<String>();
|
||||
let parent_signatures = self
|
||||
.parent_declarations
|
||||
.iter()
|
||||
.map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
|
||||
.collect();
|
||||
let language_id = buffer.language().map(|l| l.id());
|
||||
EditPredictionExcerptText {
|
||||
body,
|
||||
parent_signatures,
|
||||
language_id,
|
||||
}
|
||||
EditPredictionExcerptText { body, language_id }
|
||||
}
|
||||
|
||||
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
|
||||
@@ -79,7 +65,6 @@ impl EditPredictionExcerpt {
|
||||
query_point: Point,
|
||||
buffer: &BufferSnapshot,
|
||||
options: &EditPredictionExcerptOptions,
|
||||
syntax_index: Option<&SyntaxIndexState>,
|
||||
) -> Option<Self> {
|
||||
if buffer.len() <= options.max_bytes {
|
||||
log::debug!(
|
||||
@@ -89,11 +74,7 @@ impl EditPredictionExcerpt {
|
||||
);
|
||||
let offset_range = 0..buffer.len();
|
||||
let line_range = Line(0)..Line(buffer.max_point().row);
|
||||
return Some(EditPredictionExcerpt::new(
|
||||
offset_range,
|
||||
line_range,
|
||||
Vec::new(),
|
||||
));
|
||||
return Some(EditPredictionExcerpt::new(offset_range, line_range));
|
||||
}
|
||||
|
||||
let query_offset = query_point.to_offset(buffer);
|
||||
@@ -104,19 +85,10 @@ impl EditPredictionExcerpt {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parent_declarations = if let Some(syntax_index) = syntax_index {
|
||||
syntax_index
|
||||
.buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let excerpt_selector = ExcerptSelector {
|
||||
query_offset,
|
||||
query_range,
|
||||
query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
|
||||
parent_declarations: &parent_declarations,
|
||||
buffer,
|
||||
options,
|
||||
};
|
||||
@@ -139,20 +111,10 @@ impl EditPredictionExcerpt {
|
||||
excerpt_selector.select_lines()
|
||||
}
|
||||
|
||||
fn new(
|
||||
range: Range<usize>,
|
||||
line_range: Range<Line>,
|
||||
parent_declarations: Vec<(DeclarationId, Range<usize>)>,
|
||||
) -> Self {
|
||||
let size = range.len()
|
||||
+ parent_declarations
|
||||
.iter()
|
||||
.map(|(_, range)| range.len())
|
||||
.sum::<usize>();
|
||||
fn new(range: Range<usize>, line_range: Range<Line>) -> Self {
|
||||
Self {
|
||||
size: range.len(),
|
||||
range,
|
||||
parent_declarations,
|
||||
size,
|
||||
line_range,
|
||||
}
|
||||
}
|
||||
@@ -162,14 +124,7 @@ impl EditPredictionExcerpt {
|
||||
// this is an issue because parent_signature_ranges may be incorrect
|
||||
log::error!("bug: with_expanded_range called with disjoint range");
|
||||
}
|
||||
let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
|
||||
for (declaration_id, range) in &self.parent_declarations {
|
||||
if !range.contains_inclusive(&new_range) {
|
||||
break;
|
||||
}
|
||||
parent_declarations.push((*declaration_id, range.clone()));
|
||||
}
|
||||
Self::new(new_range, new_line_range, parent_declarations)
|
||||
Self::new(new_range, new_line_range)
|
||||
}
|
||||
|
||||
fn parent_signatures_size(&self) -> usize {
|
||||
@@ -181,7 +136,6 @@ struct ExcerptSelector<'a> {
|
||||
query_offset: usize,
|
||||
query_range: Range<usize>,
|
||||
query_line_range: Range<Line>,
|
||||
parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
|
||||
buffer: &'a BufferSnapshot,
|
||||
options: &'a EditPredictionExcerptOptions,
|
||||
}
|
||||
@@ -409,13 +363,7 @@ impl<'a> ExcerptSelector<'a> {
|
||||
}
|
||||
|
||||
fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
|
||||
let parent_declarations = self
|
||||
.parent_declarations
|
||||
.iter()
|
||||
.filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
|
||||
.map(|(id, declaration)| (*id, declaration.signature_range.clone()))
|
||||
.collect();
|
||||
EditPredictionExcerpt::new(range, line_range, parent_declarations)
|
||||
EditPredictionExcerpt::new(range, line_range)
|
||||
}
|
||||
|
||||
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
|
||||
@@ -506,9 +454,8 @@ mod tests {
|
||||
let buffer = create_buffer(&text, cx);
|
||||
let cursor_point = cursor.to_point(&buffer);
|
||||
|
||||
let excerpt =
|
||||
EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
|
||||
.expect("Should select an excerpt");
|
||||
let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
|
||||
.expect("Should select an excerpt");
|
||||
pretty_assertions::assert_eq!(
|
||||
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
|
||||
generate_marked_text(&text, &[expected_excerpt], false)
|
||||
|
||||
329
crates/edit_prediction_context/src/fake_definition_lsp.rs
Normal file
329
crates/edit_prediction_context/src/fake_definition_lsp.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
use collections::HashMap;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use language::{Language, LanguageRegistry};
|
||||
use lsp::{
|
||||
FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use project::Fs;
|
||||
use std::{ops::Range, path::PathBuf, sync::Arc};
|
||||
use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree};
|
||||
|
||||
/// Registers a fake language server that implements go-to-definition using tree-sitter,
|
||||
/// making the assumption that all names are unique, and all variables' types are
|
||||
/// explicitly declared.
|
||||
pub fn register_fake_definition_server(
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
language: Arc<Language>,
|
||||
fs: Arc<dyn Fs>,
|
||||
) -> UnboundedReceiver<FakeLanguageServer> {
|
||||
let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone())));
|
||||
|
||||
language_registry.register_fake_lsp(
|
||||
language.name(),
|
||||
language::FakeLspAdapter {
|
||||
name: "fake-definition-lsp",
|
||||
initialization_options: None,
|
||||
prettier_plugins: Vec::new(),
|
||||
disk_based_diagnostics_progress_token: None,
|
||||
disk_based_diagnostics_sources: Vec::new(),
|
||||
language_server_binary: LanguageServerBinary {
|
||||
path: PathBuf::from("fake-definition-lsp"),
|
||||
arguments: Vec::new(),
|
||||
env: None,
|
||||
},
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
definition_provider: Some(lsp::OneOf::Left(true)),
|
||||
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
||||
TextDocumentSyncKind::FULL,
|
||||
)),
|
||||
..Default::default()
|
||||
},
|
||||
label_for_completion: None,
|
||||
initializer: Some(Box::new({
|
||||
move |server| {
|
||||
server.handle_notification::<lsp::notification::DidOpenTextDocument, _>({
|
||||
let index = index.clone();
|
||||
move |params, _cx| {
|
||||
index
|
||||
.lock()
|
||||
.open_buffer(params.text_document.uri, ¶ms.text_document.text);
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidCloseTextDocument, _>({
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
move |params, cx| {
|
||||
let uri = params.text_document.uri;
|
||||
let path = uri.to_file_path().ok();
|
||||
index.lock().mark_buffer_closed(&uri);
|
||||
|
||||
if let Some(path) = path {
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
cx.spawn(async move |_cx| {
|
||||
if let Ok(content) = fs.load(&path).await {
|
||||
index.lock().index_file(uri, &content);
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidChangeWatchedFiles, _>({
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
move |params, cx| {
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
cx.spawn(async move |_cx| {
|
||||
for event in params.changes {
|
||||
if index.lock().is_buffer_open(&event.uri) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match event.typ {
|
||||
lsp::FileChangeType::DELETED => {
|
||||
index.lock().remove_definitions_for_file(&event.uri);
|
||||
}
|
||||
lsp::FileChangeType::CREATED
|
||||
| lsp::FileChangeType::CHANGED => {
|
||||
if let Some(path) = event.uri.to_file_path().ok() {
|
||||
if let Ok(content) = fs.load(&path).await {
|
||||
index.lock().index_file(event.uri, &content);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidChangeTextDocument, _>({
|
||||
let index = index.clone();
|
||||
move |params, _cx| {
|
||||
if let Some(change) = params.content_changes.into_iter().last() {
|
||||
index
|
||||
.lock()
|
||||
.index_file(params.text_document.uri, &change.text);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.handle_notification::<lsp::notification::DidChangeWorkspaceFolders, _>(
|
||||
{
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
move |params, cx| {
|
||||
let index = index.clone();
|
||||
let fs = fs.clone();
|
||||
let files = fs.as_fake().files();
|
||||
cx.spawn(async move |_cx| {
|
||||
for folder in params.event.added {
|
||||
let Ok(path) = folder.uri.to_file_path() else {
|
||||
continue;
|
||||
};
|
||||
for file in &files {
|
||||
if let Some(uri) = Uri::from_file_path(&file).ok()
|
||||
&& file.starts_with(&path)
|
||||
&& let Ok(content) = fs.load(&file).await
|
||||
{
|
||||
index.lock().index_file(uri, &content);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
server.set_request_handler::<lsp::request::GotoDefinition, _, _>({
|
||||
let index = index.clone();
|
||||
move |params, _cx| {
|
||||
let result = index.lock().get_definitions(
|
||||
params.text_document_position_params.text_document.uri,
|
||||
params.text_document_position_params.position,
|
||||
);
|
||||
async move { Ok(result) }
|
||||
}
|
||||
});
|
||||
}
|
||||
})),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// In-memory index of symbol definitions, used by the fake LSP server to
/// answer go-to-definition requests from buffer/file contents.
struct DefinitionIndex {
    // Language used to parse files (provides the grammar and outline query).
    language: Arc<Language>,
    // Symbol name -> every location where a declaration with that name was found.
    definitions: HashMap<String, Vec<lsp::Location>>,
    // Last-indexed contents per file URI, plus whether the file is open in a buffer.
    files: HashMap<Uri, FileEntry>,
}
|
||||
|
||||
#[derive(Debug)]
struct FileEntry {
    // Full text of the file as of the last indexing pass.
    contents: String,
    // True while an editor buffer for this file is open; while open, on-disk
    // change notifications for the file are ignored in favor of buffer edits.
    is_open_in_buffer: bool,
}
|
||||
|
||||
impl DefinitionIndex {
    /// Creates an empty index for files written in `language`.
    fn new(language: Arc<Language>) -> Self {
        Self {
            language,
            definitions: HashMap::default(),
            files: HashMap::default(),
        }
    }

    /// Drops every definition located in `uri` and forgets the file's contents.
    fn remove_definitions_for_file(&mut self, uri: &Uri) {
        self.definitions.retain(|_, locations| {
            // Remove locations from this file; drop the name entirely if none remain.
            locations.retain(|loc| &loc.uri != uri);
            !locations.is_empty()
        });
        self.files.remove(uri);
    }

    /// Indexes `content` and marks the file as open in a buffer.
    fn open_buffer(&mut self, uri: Uri, content: &str) {
        self.index_file_inner(uri, content, true);
    }

    /// Marks a previously opened buffer as closed; the indexed contents are kept.
    fn mark_buffer_closed(&mut self, uri: &Uri) {
        if let Some(entry) = self.files.get_mut(uri) {
            entry.is_open_in_buffer = false;
        }
    }

    /// Whether `uri` is currently open in a buffer (false for unknown files).
    fn is_buffer_open(&self, uri: &Uri) -> bool {
        self.files
            .get(uri)
            .map(|entry| entry.is_open_in_buffer)
            .unwrap_or(false)
    }

    /// Indexes `content` for a file that is not open in a buffer.
    fn index_file(&mut self, uri: Uri, content: &str) {
        self.index_file_inner(uri, content, false);
    }

    /// Re-indexes one file: clears its prior definitions, parses `content`
    /// with the language's grammar, and records declarations found via the
    /// outline query. Returns `None` (after still caching nothing) when the
    /// grammar, outline query, or parse is unavailable.
    fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> {
        self.remove_definitions_for_file(&uri);
        let grammar = self.language.grammar()?;
        let outline_config = grammar.outline_config.as_ref()?;
        let mut parser = Parser::new();
        parser.set_language(&grammar.ts_language).ok()?;
        let tree = parser.parse(content, None)?;
        let declarations = extract_declarations_from_tree(&tree, content, outline_config);
        for (name, byte_range) in declarations {
            let range = byte_range_to_lsp_range(content, byte_range);
            let location = lsp::Location {
                uri: uri.clone(),
                range,
            };
            self.definitions
                .entry(name)
                .or_insert_with(Vec::new)
                .push(location);
        }
        // Cache the contents so get_definitions can look up the word under a position.
        self.files.insert(
            uri,
            FileEntry {
                contents: content.to_string(),
                is_open_in_buffer,
            },
        );

        Some(())
    }

    /// Resolves go-to-definition: finds the word at `position` in the cached
    /// contents of `uri` and returns all recorded locations for that name.
    fn get_definitions(
        &mut self,
        uri: Uri,
        position: lsp::Position,
    ) -> Option<lsp::GotoDefinitionResponse> {
        let entry = self.files.get(&uri)?;
        let name = word_at_position(&entry.contents, position)?;
        let locations = self.definitions.get(name).cloned()?;
        Some(lsp::GotoDefinitionResponse::Array(locations))
    }
}
|
||||
|
||||
/// Runs the language's outline query over a parsed tree and returns
/// `(name, name_byte_range)` for each declaration that has both a name
/// capture and an item capture. Only the first declaration per name is kept.
fn extract_declarations_from_tree(
    tree: &Tree,
    content: &str,
    outline_config: &language::OutlineConfig,
) -> Vec<(String, Range<usize>)> {
    let mut cursor = QueryCursor::new();
    let mut declarations = Vec::new();
    let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes());
    while let Some(query_match) = matches.next() {
        let mut name_range: Option<Range<usize>> = None;
        let mut has_item_range = false;

        for capture in query_match.captures {
            let range = capture.node.byte_range();
            if capture.index == outline_config.name_capture_ix {
                // If a match has multiple name captures, the last one wins.
                name_range = Some(range);
            } else if capture.index == outline_config.item_capture_ix {
                has_item_range = true;
            }
        }

        if let Some(name_range) = name_range
            && has_item_range
        {
            let name = content[name_range.clone()].to_string();
            // Deduplicate by name: later declarations with the same name are skipped.
            if declarations.iter().any(|(n, _)| n == &name) {
                continue;
            }
            declarations.push((name, name_range));
        }
    }
    declarations
}
|
||||
|
||||
fn byte_range_to_lsp_range(content: &str, byte_range: Range<usize>) -> lsp::Range {
|
||||
let start = byte_offset_to_position(content, byte_range.start);
|
||||
let end = byte_offset_to_position(content, byte_range.end);
|
||||
lsp::Range { start, end }
|
||||
}
|
||||
|
||||
/// Translates a byte offset in `content` into a zero-based LSP position.
/// Characters are counted per Rust `char` (not UTF-16 code units); offsets
/// past the end of `content` clamp to the final position.
fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position {
    let mut position = lsp::Position {
        line: 0,
        character: 0,
    };
    let mut consumed = 0;
    let mut chars = content.chars();
    while consumed < offset {
        let Some(ch) = chars.next() else { break };
        if ch == '\n' {
            position.line += 1;
            position.character = 0;
        } else {
            position.character += 1;
        }
        consumed += ch.len_utf8();
    }
    position
}
|
||||
|
||||
fn word_at_position(content: &str, position: lsp::Position) -> Option<&str> {
|
||||
let mut lines = content.lines();
|
||||
let line = lines.nth(position.line as usize)?;
|
||||
let column = position.character as usize;
|
||||
if column > line.len() {
|
||||
return None;
|
||||
}
|
||||
let start = line[..column]
|
||||
.rfind(|c: char| !c.is_alphanumeric() && c != '_')
|
||||
.map(|i| i + 1)
|
||||
.unwrap_or(0);
|
||||
let end = line[column..]
|
||||
.find(|c: char| !c.is_alphanumeric() && c != '_')
|
||||
.map(|i| i + column)
|
||||
.unwrap_or(line.len());
|
||||
Some(&line[start..end]).filter(|word| !word.is_empty())
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,126 +0,0 @@
|
||||
use language::{BufferSnapshot, SyntaxMapMatches};
|
||||
use std::{cmp::Reverse, ops::Range};
|
||||
|
||||
use crate::declaration::Identifier;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * how to handle multiple name captures? for now last one wins
|
||||
//
|
||||
// * annotation ranges
|
||||
//
|
||||
// * new "signature" capture for outline queries
|
||||
//
|
||||
// * Check parent behavior of "int x, y = 0" declarations in a test
|
||||
|
||||
/// A declaration discovered by the outline query, with its position in the
/// declaration hierarchy.
pub struct OutlineDeclaration {
    // Index into the enclosing `Vec<OutlineDeclaration>` of the innermost
    // containing declaration, or `None` at top level.
    pub parent_index: Option<usize>,
    pub identifier: Identifier,
    // Byte range of the whole declaration item.
    pub item_range: Range<usize>,
    // Byte range spanning the name/context captures (the "signature").
    pub signature_range: Range<usize>,
}
|
||||
|
||||
/// Collects every outline declaration in the buffer, with `parent_index` populated.
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
    declarations_overlapping_range(0..buffer.len(), buffer)
}
|
||||
|
||||
/// Collects outline declarations overlapping `range`, sorted by start offset
/// (outermost first for equal starts), and links each declaration to its
/// innermost containing declaration via `parent_index`.
pub fn declarations_overlapping_range(
    range: Range<usize>,
    buffer: &BufferSnapshot,
) -> Vec<OutlineDeclaration> {
    let mut declarations = OutlineIterator::new(range, buffer).collect::<Vec<_>>();
    // Sort by start ascending, end descending, so a container precedes everything inside it.
    declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end)));

    // Stack of (index, item_range) for declarations that still enclose the current position.
    let mut parent_stack: Vec<(usize, Range<usize>)> = Vec::new();
    for (index, declaration) in declarations.iter_mut().enumerate() {
        while let Some((top_parent_index, top_parent_range)) = parent_stack.last() {
            if declaration.item_range.start >= top_parent_range.end {
                // This declaration starts after the stack top ends, so that one is closed.
                parent_stack.pop();
            } else {
                declaration.parent_index = Some(*top_parent_index);
                break;
            }
        }
        parent_stack.push((index, declaration.item_range.clone()));
    }
    declarations
}
|
||||
|
||||
/// Iterates outline items without being ordered w.r.t. nested items and without populating
/// `parent`.
pub struct OutlineIterator<'a> {
    // Snapshot the byte ranges and name text are resolved against.
    buffer: &'a BufferSnapshot,
    // Streaming matches of the outline query over the requested range.
    matches: SyntaxMapMatches<'a>,
}
|
||||
|
||||
impl<'a> OutlineIterator<'a> {
    /// Starts an outline-query match stream over `range` of `buffer`.
    pub fn new(range: Range<usize>, buffer: &'a BufferSnapshot) -> Self {
        let matches = buffer.syntax.matches(range, &buffer.text, |grammar| {
            grammar.outline_config.as_ref().map(|c| &c.query)
        });

        Self { buffer, matches }
    }
}
|
||||
|
||||
impl<'a> Iterator for OutlineIterator<'a> {
    type Item = OutlineDeclaration;

    fn next(&mut self) -> Option<Self::Item> {
        // Loop until a match yields a complete declaration (name + item +
        // signature captures all present); incomplete matches are skipped.
        while let Some(mat) = self.matches.peek() {
            // Safe to unwrap: only grammars with an outline query produce matches here.
            let config = self.matches.grammars()[mat.grammar_index]
                .outline_config
                .as_ref()
                .unwrap();

            let mut name_range = None;
            let mut item_range = None;
            let mut signature_start = None;
            let mut signature_end = None;

            // The signature grows to span every name/context capture seen, in order.
            let mut add_to_signature = |range: Range<usize>| {
                if signature_start.is_none() {
                    signature_start = Some(range.start);
                }
                signature_end = Some(range.end);
            };

            for capture in mat.captures {
                let range = capture.node.byte_range();
                if capture.index == config.name_capture_ix {
                    // Multiple name captures: last one wins.
                    name_range = Some(range.clone());
                    add_to_signature(range);
                } else if Some(capture.index) == config.context_capture_ix
                    || Some(capture.index) == config.extra_context_capture_ix
                {
                    add_to_signature(range);
                } else if capture.index == config.item_capture_ix {
                    item_range = Some(range.clone());
                }
            }

            // Capture what we need before advancing, which invalidates `mat`.
            let language_id = mat.language.id();
            self.matches.advance();

            if let Some(name_range) = name_range
                && let Some(item_range) = item_range
                && let Some(signature_start) = signature_start
                && let Some(signature_end) = signature_end
            {
                let name = self
                    .buffer
                    .text_for_range(name_range)
                    .collect::<String>()
                    .into();

                return Some(OutlineDeclaration {
                    identifier: Identifier { name, language_id },
                    item_range: item_range,
                    signature_range: signature_start..signature_end,
                    // Populated later by `declarations_overlapping_range`.
                    parent_index: None,
                });
            }
        }
        None
    }
}
|
||||
@@ -1,173 +0,0 @@
|
||||
use collections::HashMap;
|
||||
use language::BufferSnapshot;
|
||||
use std::ops::Range;
|
||||
use util::RangeExt;
|
||||
|
||||
use crate::{
|
||||
declaration::Identifier,
|
||||
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
|
||||
};
|
||||
|
||||
/// A single occurrence of an identifier within an excerpt, found via the
/// highlights query.
#[derive(Debug, Clone)]
pub struct Reference {
    pub identifier: Identifier,
    // Byte range of the occurrence in the buffer.
    pub range: Range<usize>,
    // Which part of the excerpt the occurrence came from.
    pub region: ReferenceRegion,
}
|
||||
|
||||
/// Where within an excerpt a reference was found.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ReferenceRegion {
    // Found in a parent declaration's signature ("breadcrumb") above the excerpt body.
    Breadcrumb,
    // Found in the excerpt body itself.
    Nearby,
}
|
||||
|
||||
/// Gathers identifier references from an excerpt's body (`Nearby`) and from
/// its parent declarations' signatures (`Breadcrumb`), grouped by identifier.
pub fn references_in_excerpt(
    excerpt: &EditPredictionExcerpt,
    excerpt_text: &EditPredictionExcerptText,
    snapshot: &BufferSnapshot,
) -> HashMap<Identifier, Vec<Reference>> {
    let mut references = references_in_range(
        excerpt.range.clone(),
        excerpt_text.body.as_str(),
        ReferenceRegion::Nearby,
        snapshot,
    );

    // Parent declarations and their signature texts are parallel sequences.
    for ((_, range), text) in excerpt
        .parent_declarations
        .iter()
        .zip(excerpt_text.parent_signatures.iter())
    {
        references.extend(references_in_range(
            range.clone(),
            text.as_str(),
            ReferenceRegion::Breadcrumb,
            snapshot,
        ));
    }

    // Group the flat reference list by identifier.
    let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::default();
    for reference in references {
        identifier_to_references
            .entry(reference.identifier.clone())
            .or_insert_with(Vec::new)
            .push(reference);
    }
    identifier_to_references
}
|
||||
|
||||
/// Finds all nodes which have a "variable" match from the highlights query within the offset range.
///
/// `range_text` must be the buffer text of `range`; identifier text is sliced
/// out of it rather than re-read from the buffer. Only nodes fully contained
/// in `range` are reported.
pub fn references_in_range(
    range: Range<usize>,
    range_text: &str,
    reference_region: ReferenceRegion,
    buffer: &BufferSnapshot,
) -> Vec<Reference> {
    let mut matches = buffer
        .syntax
        .matches(range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    let mut references = Vec::new();
    // Byte range of the most recently pushed reference, for deduplication.
    let mut last_added_range = None;
    while let Some(mat) = matches.peek() {
        let config = matches.grammars()[mat.grammar_index]
            .highlights_config
            .as_ref();

        if let Some(config) = config {
            for capture in mat.captures {
                if config.identifier_capture_indices.contains(&capture.index) {
                    let node_range = capture.node.byte_range();

                    // sometimes multiple highlight queries match - this deduplicates them
                    if Some(node_range.clone()) == last_added_range {
                        continue;
                    }

                    // Skip nodes that spill outside the requested range.
                    if !range.contains_inclusive(&node_range) {
                        continue;
                    }

                    let identifier_text =
                        &range_text[node_range.start - range.start..node_range.end - range.start];

                    references.push(Reference {
                        identifier: Identifier {
                            name: identifier_text.into(),
                            language_id: mat.language.id(),
                        },
                        range: node_range.clone(),
                        region: reference_region,
                    });
                    last_added_range = Some(node_range);
                }
            }
        }

        matches.advance();
    }
    references
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use gpui::{TestAppContext, prelude::*};
|
||||
use indoc::indoc;
|
||||
use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||
|
||||
use crate::reference::{ReferenceRegion, references_in_range};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_identifier_node_truncated(cx: &mut TestAppContext) {
|
||||
let code = indoc! { r#"
|
||||
fn main() {
|
||||
add(1, 2);
|
||||
}
|
||||
|
||||
fn add(a: i32, b: i32) -> i32 {
|
||||
a + b
|
||||
}
|
||||
"# };
|
||||
let buffer = create_buffer(code, cx);
|
||||
|
||||
let range = 0..35;
|
||||
let references = references_in_range(
|
||||
range.clone(),
|
||||
&code[range],
|
||||
ReferenceRegion::Breadcrumb,
|
||||
&buffer,
|
||||
);
|
||||
assert_eq!(references.len(), 2);
|
||||
assert_eq!(references[0].identifier.name.as_ref(), "main");
|
||||
assert_eq!(references[1].identifier.name.as_ref(), "add");
|
||||
}
|
||||
|
||||
fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
|
||||
let buffer =
|
||||
cx.new(|cx| language::Buffer::local(text, cx).with_language(rust_lang().into(), cx));
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot())
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
|
||||
.unwrap()
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,314 +0,0 @@
|
||||
use hashbrown::HashTable;
|
||||
use regex::Regex;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
hash::{Hash, Hasher as _},
|
||||
path::Path,
|
||||
sync::LazyLock,
|
||||
};
|
||||
use util::rel_path::RelPath;
|
||||
|
||||
use crate::reference::Reference;
|
||||
|
||||
// TODO: Consider implementing sliding window similarity matching like
|
||||
// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
|
||||
//
|
||||
// That implementation could actually be more efficient - no need to track words in the window that
|
||||
// are not in the query.
|
||||
|
||||
// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
|
||||
// two in parallel.
|
||||
|
||||
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
|
||||
|
||||
/// Multiset of text occurrences for text similarity that only stores hashes and counts.
#[derive(Debug, Default)]
pub struct Occurrences {
    // One entry per distinct hash, keyed by the hash itself.
    table: HashTable<OccurrenceEntry>,
    // Sum of all entry counts (multiset cardinality), used by weighted measures.
    total_count: usize,
}
|
||||
|
||||
/// A single hash bucket in `Occurrences`: the hashed identifier part and how
/// many times it occurred.
#[derive(Debug)]
struct OccurrenceEntry {
    hash: u64,
    count: usize,
}
|
||||
|
||||
impl Occurrences {
    /// Builds occurrences from every `\w+` word in `text`.
    pub fn within_string(text: &str) -> Self {
        Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
    }

    #[allow(dead_code)]
    /// Builds occurrences from the identifier names of `references`.
    pub fn within_references(references: &[Reference]) -> Self {
        Self::from_identifiers(
            references
                .iter()
                .map(|reference| reference.identifier.name.as_ref()),
        )
    }

    /// Builds occurrences by splitting each identifier into case/delimiter
    /// parts and hashing each part lowercased (so matching is case-insensitive).
    pub fn from_identifiers(identifiers: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
        let mut this = Self::default();
        // TODO: Score matches that match case higher?
        //
        // TODO: Also include unsplit identifier?
        for identifier in identifiers {
            for identifier_part in split_identifier(identifier.as_ref()) {
                this.add_hash(fx_hash(&identifier_part.to_lowercase()));
            }
        }
        this
    }

    /// Builds occurrences from a worktree-relative path, prepending the
    /// worktree name when one is provided.
    pub fn from_worktree_path(worktree_name: Option<Cow<'_, str>>, rel_path: &RelPath) -> Self {
        if let Some(worktree_name) = worktree_name {
            Self::from_identifiers(
                std::iter::once(worktree_name)
                    .chain(iter_path_without_extension(rel_path.as_std_path())),
            )
        } else {
            Self::from_path(rel_path.as_std_path())
        }
    }

    /// Builds occurrences from a path's components, dropping the file extension.
    pub fn from_path(path: &Path) -> Self {
        Self::from_identifiers(iter_path_without_extension(path))
    }

    /// Increments the count for `hash`, inserting a new entry on first sight.
    fn add_hash(&mut self, hash: u64) {
        self.table
            .entry(
                hash,
                |entry: &OccurrenceEntry| entry.hash == hash,
                |entry| entry.hash,
            )
            .and_modify(|entry| entry.count += 1)
            .or_insert(OccurrenceEntry { hash, count: 1 });
        self.total_count += 1;
    }

    /// Whether `hash` occurs at least once.
    fn contains_hash(&self, hash: u64) -> bool {
        self.get_count(hash) != 0
    }

    /// Number of occurrences recorded for `hash` (0 if absent).
    fn get_count(&self, hash: u64) -> usize {
        self.table
            .find(hash, |entry| entry.hash == hash)
            .map(|entry| entry.count)
            .unwrap_or(0)
    }
}
|
||||
|
||||
/// Yields each component of `path` as a string, except that the final
/// component is reduced to its file stem (extension removed). An empty path
/// yields nothing.
fn iter_path_without_extension(path: &Path) -> impl Iterator<Item = Cow<'_, str>> {
    let component_count = path.components().count();
    path.components()
        .take(component_count.saturating_sub(1))
        .map(|component| component.as_os_str().to_string_lossy())
        .chain(path.file_stem().map(|stem| stem.to_string_lossy()))
}
|
||||
|
||||
/// Hashes `data` with the (fast, non-cryptographic) FxHasher; used as the key
/// for `Occurrences` entries.
pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
    let mut hasher = collections::FxHasher::default();
    data.hash(&mut hasher);
    hasher.finish()
}
|
||||
|
||||
// Splits camelcase / snakecase / kebabcase / pascalcase identifiers into
// their constituent words.
//
// Uses byte offsets from `char_indices` so multi-byte (non-ASCII) characters
// are always sliced on valid UTF-8 boundaries. The previous implementation
// indexed the string with *character* counts, which mis-sliced (and could
// panic) on identifiers containing non-ASCII word characters — reachable
// because `IDENTIFIER_REGEX` (`\b\w+\b`) matches Unicode word characters.
fn split_identifier(identifier: &str) -> Vec<&str> {
    let chars: Vec<(usize, char)> = identifier.char_indices().collect();
    let mut parts = Vec::new();

    if chars.is_empty() {
        return parts;
    }

    // Byte offset where the current part begins.
    let mut start = 0;
    for i in 0..chars.len() {
        let (byte_ix, ch) = chars[i];

        // Explicit delimiters (underscore and hyphen) end the current part
        // and are themselves skipped.
        if ch == '_' || ch == '-' {
            if byte_ix > start {
                parts.push(&identifier[start..byte_ix]);
            }
            start = byte_ix + ch.len_utf8();
            continue;
        }

        // Handle camelCase and PascalCase transitions.
        if i > 0 {
            let prev_char = chars[i - 1].1;

            // Transition from lowercase/digit to uppercase: "camelCase" -> "camel" | "Case".
            if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
                if byte_ix > start {
                    parts.push(&identifier[start..byte_ix]);
                }
                start = byte_ix;
            }
            // Acronym boundary: "XMLParser" -> "XML" | "Parser".
            else if ch.is_uppercase()
                && prev_char.is_uppercase()
                && chars.get(i + 1).is_some_and(|&(_, next)| next.is_lowercase())
            {
                if byte_ix > start {
                    parts.push(&identifier[start..byte_ix]);
                }
                start = byte_ix;
            }
        }
    }

    // Add the last part if there's any remaining.
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts
}
|
||||
|
||||
/// Unweighted Jaccard similarity of the two hash sets: |A ∩ B| / |A ∪ B|.
/// Counts are ignored; only distinct hashes matter.
// NOTE(review): returns NaN (0/0) when both sets are empty — confirm callers
// never compare two empty occurrence sets.
pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
    // Iterate over the smaller set; lookups hit the larger one.
    if set_a.table.len() > set_b.table.len() {
        std::mem::swap(&mut set_a, &mut set_b);
    }
    let intersection = set_a
        .table
        .iter()
        .filter(|entry| set_b.contains_hash(entry.hash))
        .count();
    let union = set_a.table.len() + set_b.table.len() - intersection;
    intersection as f32 / union as f32
}
|
||||
|
||||
// TODO
/// Unweighted overlap (Szymkiewicz–Simpson) coefficient: |A ∩ B| / min(|A|, |B|).
// NOTE(review): returns NaN (0/0) when the smaller set is empty — confirm
// callers guard against empty inputs.
#[allow(dead_code)]
pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
    // After the swap, `set_a` is the smaller set.
    if set_a.table.len() > set_b.table.len() {
        std::mem::swap(&mut set_a, &mut set_b);
    }
    let intersection = set_a
        .table
        .iter()
        .filter(|entry| set_b.contains_hash(entry.hash))
        .count();
    intersection as f32 / set_a.table.len() as f32
}
|
||||
|
||||
// TODO
/// Multiset (weighted) Jaccard similarity: sum of min counts over sum of max
/// counts, with hashes present only in `set_b` contributing their full count
/// to the denominator. Returns 0.0 when both sets are empty.
#[allow(dead_code)]
pub fn weighted_jaccard_similarity<'a>(
    mut set_a: &'a Occurrences,
    mut set_b: &'a Occurrences,
) -> f32 {
    // Iterate over the smaller table; `set_b`-only hashes are accounted for
    // via `total_count - used_count_b` below.
    if set_a.table.len() > set_b.table.len() {
        std::mem::swap(&mut set_a, &mut set_b);
    }

    let mut numerator = 0;
    let mut denominator_a = 0;
    let mut used_count_b = 0;
    for entry_a in set_a.table.iter() {
        let count_a = entry_a.count;
        let count_b = set_b.get_count(entry_a.hash);
        numerator += count_a.min(count_b);
        denominator_a += count_a.max(count_b);
        used_count_b += count_b;
    }

    // Add the weight of set_b hashes never visited above.
    let denominator = denominator_a + (set_b.total_count - used_count_b);
    if denominator == 0 {
        0.0
    } else {
        numerator as f32 / denominator as f32
    }
}
|
||||
|
||||
/// Multiset (weighted) overlap coefficient: sum of min counts divided by the
/// total weight of the lighter multiset. Returns 0.0 when either set is empty.
pub fn weighted_overlap_coefficient<'a>(
    mut set_a: &'a Occurrences,
    mut set_b: &'a Occurrences,
) -> f32 {
    // Iterate over the set with fewer distinct hashes.
    if set_a.table.len() > set_b.table.len() {
        std::mem::swap(&mut set_a, &mut set_b);
    }

    let mut numerator = 0;
    for entry_a in set_a.table.iter() {
        let count_a = entry_a.count;
        let count_b = set_b.get_count(entry_a.hash);
        numerator += count_a.min(count_b);
    }

    let denominator = set_a.total_count.min(set_b.total_count);
    if denominator == 0 {
        0.0
    } else {
        numerator as f32 / denominator as f32
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_split_identifier() {
|
||||
assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
|
||||
assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
|
||||
assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
|
||||
assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
|
||||
assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_similarity_functions() {
|
||||
// 10 identifier parts, 8 unique
|
||||
// Repeats: 2 "outline", 2 "items"
|
||||
let set_a = Occurrences::within_string(
|
||||
"let mut outline_items = query_outline_items(&language, &tree, &source);",
|
||||
);
|
||||
// 14 identifier parts, 11 unique
|
||||
// Repeats: 2 "outline", 2 "language", 2 "tree"
|
||||
let set_b = Occurrences::within_string(
|
||||
"pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
|
||||
);
|
||||
|
||||
// 6 overlaps: "outline", "items", "query", "language", "tree", "source"
|
||||
// 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
|
||||
assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
|
||||
|
||||
// Numerator is one more than before due to both having 2 "outline".
|
||||
// Denominator is the same except for 3 more due to the non-overlapping duplicates
|
||||
assert_eq!(
|
||||
weighted_jaccard_similarity(&set_a, &set_b),
|
||||
7.0 / (7.0 + 7.0 + 3.0)
|
||||
);
|
||||
|
||||
// Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
|
||||
assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
|
||||
|
||||
// Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
|
||||
// the smaller set, 10.
|
||||
assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iter_path_without_extension() {
|
||||
let mut iter = iter_path_without_extension(Path::new(""));
|
||||
assert_eq!(iter.next(), None);
|
||||
|
||||
let iter = iter_path_without_extension(Path::new("foo"));
|
||||
assert_eq!(iter.collect::<Vec<_>>(), ["foo"]);
|
||||
|
||||
let iter = iter_path_without_extension(Path::new("foo/bar.txt"));
|
||||
assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar"]);
|
||||
|
||||
let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt"));
|
||||
assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar", "baz"]);
|
||||
}
|
||||
}
|
||||
17
crates/edit_prediction_types/Cargo.toml
Normal file
17
crates/edit_prediction_types/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "edit_prediction_types"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/edit_prediction_types.rs"
|
||||
|
||||
[dependencies]
|
||||
client.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
298
crates/edit_prediction_types/src/edit_prediction_types.rs
Normal file
298
crates/edit_prediction_types/src/edit_prediction_types.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use client::EditPredictionUsage;
|
||||
use gpui::{App, Context, Entity, SharedString};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
|
||||
|
||||
// TODO: Find a better home for `Direction`.
//
// This should live in an ancestor crate of `editor` and `edit_prediction`,
// but at time of writing there isn't an obvious spot.
/// Direction to cycle through available edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Direction {
    Prev,
    Next,
}
|
||||
|
||||
/// A prediction produced by an edit-prediction provider.
#[derive(Clone)]
pub enum EditPrediction {
    /// Edits within the buffer that requested the prediction
    Local {
        // Provider-assigned identifier for telemetry/deduplication, if any.
        id: Option<SharedString>,
        // Anchored edit ranges paired with their replacement text.
        edits: Vec<(Range<language::Anchor>, Arc<str>)>,
        edit_preview: Option<language::EditPreview>,
    },
    /// Jump to a different file from the one that requested the prediction
    Jump {
        id: Option<SharedString>,
        // Snapshot of the buffer being jumped to.
        snapshot: language::BufferSnapshot,
        // Position in that buffer to land on.
        target: language::Anchor,
    },
}
|
||||
|
||||
/// Whether (and how) a provider collects edit-prediction training data.
pub enum DataCollectionState {
    /// The provider doesn't support data collection.
    Unsupported,
    /// Data collection is enabled.
    Enabled { is_project_open_source: bool },
    /// Data collection is disabled or unanswered.
    Disabled { is_project_open_source: bool },
}
|
||||
|
||||
impl DataCollectionState {
|
||||
pub fn is_supported(&self) -> bool {
|
||||
!matches!(self, DataCollectionState::Unsupported)
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
matches!(self, DataCollectionState::Enabled { .. })
|
||||
}
|
||||
|
||||
pub fn is_project_open_source(&self) -> bool {
|
||||
match self {
|
||||
Self::Enabled {
|
||||
is_project_open_source,
|
||||
}
|
||||
| Self::Disabled {
|
||||
is_project_open_source,
|
||||
} => *is_project_open_source,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Implemented by each edit-prediction provider. Static methods describe the
/// provider; instance methods drive the request/accept lifecycle. Methods with
/// default bodies are optional for providers.
pub trait EditPredictionDelegate: 'static + Sized {
    /// Machine-readable provider name.
    fn name() -> &'static str;
    /// Human-readable provider name for UI.
    fn display_name() -> &'static str;
    fn show_predictions_in_menu() -> bool;
    fn show_tab_accept_marker() -> bool {
        false
    }
    fn supports_jump_to_edit() -> bool {
        true
    }

    /// Data-collection status; defaults to unsupported.
    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
        DataCollectionState::Unsupported
    }

    /// Current usage/quota information, if the provider tracks it.
    fn usage(&self, _cx: &App) -> Option<EditPredictionUsage> {
        None
    }

    fn toggle_data_collection(&mut self, _cx: &mut App) {}
    /// Whether predictions are enabled at this buffer position.
    fn is_enabled(
        &self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &App,
    ) -> bool;
    /// Whether a prediction request is currently in flight.
    fn is_refreshing(&self, cx: &App) -> bool;
    /// Requests a new prediction for the given position.
    fn refresh(
        &mut self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        debounce: bool,
        cx: &mut Context<Self>,
    );
    /// Moves to the previous/next available prediction.
    fn cycle(
        &mut self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        direction: Direction,
        cx: &mut Context<Self>,
    );
    /// Called when the user accepts the current prediction.
    fn accept(&mut self, cx: &mut Context<Self>);
    /// Called when the current prediction is dismissed.
    fn discard(&mut self, cx: &mut Context<Self>);
    /// Called when a prediction becomes visible to the user.
    fn did_show(&mut self, _cx: &mut Context<Self>) {}
    /// Returns the prediction to display at this position, if any.
    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<EditPrediction>;
}
|
||||
|
||||
/// Object-safe mirror of [`EditPredictionDelegate`], implemented for
/// `Entity<T>` so providers of different concrete types can be handled
/// uniformly (e.g. behind `dyn EditPredictionDelegateHandle`). Each method
/// corresponds one-to-one with the delegate trait.
pub trait EditPredictionDelegateHandle {
    fn name(&self) -> &'static str;
    fn display_name(&self) -> &'static str;
    fn is_enabled(
        &self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &App,
    ) -> bool;
    fn show_predictions_in_menu(&self) -> bool;
    fn show_tab_accept_marker(&self) -> bool;
    fn supports_jump_to_edit(&self) -> bool;
    fn data_collection_state(&self, cx: &App) -> DataCollectionState;
    fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
    fn toggle_data_collection(&self, cx: &mut App);
    fn is_refreshing(&self, cx: &App) -> bool;
    fn refresh(
        &self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        debounce: bool,
        cx: &mut App,
    );
    fn cycle(
        &self,
        buffer: Entity<Buffer>,
        cursor_position: language::Anchor,
        direction: Direction,
        cx: &mut App,
    );
    fn did_show(&self, cx: &mut App);
    fn accept(&self, cx: &mut App);
    fn discard(&self, cx: &mut App);
    fn suggest(
        &self,
        buffer: &Entity<Buffer>,
        cursor_position: language::Anchor,
        cx: &mut App,
    ) -> Option<EditPrediction>;
}
|
||||
|
||||
impl<T> EditPredictionDelegateHandle for Entity<T>
|
||||
where
|
||||
T: EditPredictionDelegate,
|
||||
{
|
||||
fn name(&self) -> &'static str {
|
||||
T::name()
|
||||
}
|
||||
|
||||
fn display_name(&self) -> &'static str {
|
||||
T::display_name()
|
||||
}
|
||||
|
||||
fn show_predictions_in_menu(&self) -> bool {
|
||||
T::show_predictions_in_menu()
|
||||
}
|
||||
|
||||
fn show_tab_accept_marker(&self) -> bool {
|
||||
T::show_tab_accept_marker()
|
||||
}
|
||||
|
||||
fn supports_jump_to_edit(&self) -> bool {
|
||||
T::supports_jump_to_edit()
|
||||
}
|
||||
|
||||
fn data_collection_state(&self, cx: &App) -> DataCollectionState {
|
||||
self.read(cx).data_collection_state(cx)
|
||||
}
|
||||
|
||||
fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
||||
self.read(cx).usage(cx)
|
||||
}
|
||||
|
||||
fn toggle_data_collection(&self, cx: &mut App) {
|
||||
self.update(cx, |this, cx| this.toggle_data_collection(cx))
|
||||
}
|
||||
|
||||
fn is_enabled(
|
||||
&self,
|
||||
buffer: &Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
self.read(cx).is_enabled(buffer, cursor_position, cx)
|
||||
}
|
||||
|
||||
fn is_refreshing(&self, cx: &App) -> bool {
|
||||
self.read(cx).is_refreshing(cx)
|
||||
}
|
||||
|
||||
fn refresh(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.update(cx, |this, cx| {
|
||||
this.refresh(buffer, cursor_position, debounce, cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn cycle(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
direction: Direction,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.update(cx, |this, cx| {
|
||||
this.cycle(buffer, cursor_position, direction, cx)
|
||||
})
|
||||
}
|
||||
|
||||
fn accept(&self, cx: &mut App) {
|
||||
self.update(cx, |this, cx| this.accept(cx))
|
||||
}
|
||||
|
||||
fn discard(&self, cx: &mut App) {
|
||||
self.update(cx, |this, cx| this.discard(cx))
|
||||
}
|
||||
|
||||
fn did_show(&self, cx: &mut App) {
|
||||
self.update(cx, |this, cx| this.did_show(cx))
|
||||
}
|
||||
|
||||
fn suggest(
|
||||
&self,
|
||||
buffer: &Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut App,
|
||||
) -> Option<EditPrediction> {
|
||||
self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
|
||||
/// edit is not a prefix of a predicted insertion.
|
||||
pub fn interpolate_edits(
|
||||
old_snapshot: &BufferSnapshot,
|
||||
new_snapshot: &BufferSnapshot,
|
||||
current_edits: &[(Range<Anchor>, Arc<str>)],
|
||||
) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
|
||||
let mut edits = Vec::new();
|
||||
|
||||
let mut model_edits = current_edits.iter().peekable();
|
||||
for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
|
||||
while let Some((model_old_range, _)) = model_edits.peek() {
|
||||
let model_old_range = model_old_range.to_offset(old_snapshot);
|
||||
if model_old_range.end < user_edit.old.start {
|
||||
let (model_old_range, model_new_text) = model_edits.next().unwrap();
|
||||
edits.push((model_old_range.clone(), model_new_text.clone()));
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((model_old_range, model_new_text)) = model_edits.peek() {
|
||||
let model_old_offset_range = model_old_range.to_offset(old_snapshot);
|
||||
if user_edit.old == model_old_offset_range {
|
||||
let user_new_text = new_snapshot
|
||||
.text_for_range(user_edit.new.clone())
|
||||
.collect::<String>();
|
||||
|
||||
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
|
||||
if !model_suffix.is_empty() {
|
||||
let anchor = old_snapshot.anchor_after(user_edit.old.end);
|
||||
edits.push((anchor..anchor, model_suffix.into()));
|
||||
}
|
||||
|
||||
model_edits.next();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
|
||||
edits.extend(model_edits.cloned());
|
||||
|
||||
if edits.is_empty() { None } else { Some(edits) }
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "edit_prediction_button"
|
||||
name = "edit_prediction_ui"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
@@ -9,35 +9,43 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/edit_prediction_button.rs"
|
||||
path = "src/edit_prediction_ui.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
codestral.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
copilot.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
edit_prediction_types.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indoc.workspace = true
|
||||
language.workspace = true
|
||||
markdown.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
settings.workspace = true
|
||||
supermaven.workspace = true
|
||||
telemetry.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
menu.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
copilot = { workspace = true, features = ["test-support"] }
|
||||
@@ -1,16 +1,14 @@
|
||||
mod sweep_api_token_modal;
|
||||
|
||||
pub use sweep_api_token_modal::SweepApiKeyModal;
|
||||
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use cloud_llm_client::UsageLimit;
|
||||
use codestral::CodestralCompletionProvider;
|
||||
use codestral::CodestralEditPredictionDelegate;
|
||||
use copilot::{Copilot, Status};
|
||||
use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
use edit_prediction_types::EditPredictionDelegateHandle;
|
||||
use editor::{
|
||||
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag};
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, AsyncWindowContext, Corner, Entity, FocusHandle,
|
||||
@@ -44,7 +42,11 @@ use workspace::{
|
||||
notifications::NotificationId,
|
||||
};
|
||||
use zed_actions::OpenBrowser;
|
||||
use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag};
|
||||
|
||||
use crate::{
|
||||
RatePredictions, SweepApiKeyModal,
|
||||
rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
|
||||
};
|
||||
|
||||
actions!(
|
||||
edit_prediction,
|
||||
@@ -67,7 +69,7 @@ pub struct EditPredictionButton {
|
||||
editor_focus_handle: Option<FocusHandle>,
|
||||
language: Option<Arc<Language>>,
|
||||
file: Option<Arc<dyn File>>,
|
||||
edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>,
|
||||
edit_prediction_provider: Option<Arc<dyn EditPredictionDelegateHandle>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
user_store: Entity<UserStore>,
|
||||
popover_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
@@ -244,7 +246,7 @@ impl Render for EditPredictionButton {
|
||||
|
||||
EditPredictionProvider::Codestral => {
|
||||
let enabled = self.editor_enabled.unwrap_or(true);
|
||||
let has_api_key = CodestralCompletionProvider::has_api_key(cx);
|
||||
let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx);
|
||||
let fs = self.fs.clone();
|
||||
let this = cx.weak_entity();
|
||||
|
||||
@@ -317,16 +319,16 @@ impl Render for EditPredictionButton {
|
||||
);
|
||||
|
||||
let sweep_missing_token = is_sweep
|
||||
&& !zeta::Zeta::try_global(cx)
|
||||
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
|
||||
&& !edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
|
||||
|
||||
let zeta_icon = match (is_sweep, enabled) {
|
||||
let ep_icon = match (is_sweep, enabled) {
|
||||
(true, _) => IconName::SweepAi,
|
||||
(false, true) => IconName::ZedPredict,
|
||||
(false, false) => IconName::ZedPredictDisabled,
|
||||
};
|
||||
|
||||
if zeta::should_show_upsell_modal() {
|
||||
if edit_prediction::should_show_upsell_modal() {
|
||||
let tooltip_meta = if self.user_store.read(cx).current_user().is_some() {
|
||||
"Choose a Plan"
|
||||
} else {
|
||||
@@ -334,7 +336,7 @@ impl Render for EditPredictionButton {
|
||||
};
|
||||
|
||||
return div().child(
|
||||
IconButton::new("zed-predict-pending-button", zeta_icon)
|
||||
IconButton::new("zed-predict-pending-button", ep_icon)
|
||||
.shape(IconButtonShape::Square)
|
||||
.indicator(Indicator::dot().color(Color::Muted))
|
||||
.indicator_border_color(Some(cx.theme().colors().status_bar_background))
|
||||
@@ -379,7 +381,7 @@ impl Render for EditPredictionButton {
|
||||
None
|
||||
};
|
||||
|
||||
let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
|
||||
let icon_button = IconButton::new("zed-predict-pending-button", ep_icon)
|
||||
.shape(IconButtonShape::Square)
|
||||
.when_some(indicator_color, |this, color| {
|
||||
this.indicator(Indicator::dot().color(color))
|
||||
@@ -419,13 +421,13 @@ impl Render for EditPredictionButton {
|
||||
|
||||
let this = cx.weak_entity();
|
||||
|
||||
let mut popover_menu = PopoverMenu::new("zeta")
|
||||
let mut popover_menu = PopoverMenu::new("edit-prediction")
|
||||
.when(user.is_some(), |popover_menu| {
|
||||
let this = this.clone();
|
||||
|
||||
popover_menu.menu(move |window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.build_zeta_context_menu(provider, window, cx)
|
||||
this.build_edit_prediction_context_menu(provider, window, cx)
|
||||
})
|
||||
.ok()
|
||||
})
|
||||
@@ -485,7 +487,7 @@ impl EditPredictionButton {
|
||||
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
|
||||
.detach();
|
||||
|
||||
CodestralCompletionProvider::ensure_api_key_loaded(client.http_client(), cx);
|
||||
CodestralEditPredictionDelegate::ensure_api_key_loaded(client.http_client(), cx);
|
||||
|
||||
Self {
|
||||
editor_subscription: None,
|
||||
@@ -520,7 +522,7 @@ impl EditPredictionButton {
|
||||
}
|
||||
}
|
||||
|
||||
if CodestralCompletionProvider::has_api_key(cx) {
|
||||
if CodestralEditPredictionDelegate::has_api_key(cx) {
|
||||
providers.push(EditPredictionProvider::Codestral);
|
||||
}
|
||||
|
||||
@@ -599,8 +601,8 @@ impl EditPredictionButton {
|
||||
EditPredictionProvider::Experimental(
|
||||
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
|
||||
) => {
|
||||
let has_api_token = zeta::Zeta::try_global(cx)
|
||||
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
|
||||
let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
|
||||
.map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
|
||||
|
||||
let should_open_modal = !has_api_token || is_current;
|
||||
|
||||
@@ -947,8 +949,8 @@ impl EditPredictionButton {
|
||||
)
|
||||
.context(editor_focus_handle)
|
||||
.when(
|
||||
cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>(),
|
||||
|this| this.action("Rate Completions", RateCompletions.boxed_clone()),
|
||||
cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
|
||||
|this| this.action("Rate Predictions", RatePredictions.boxed_clone()),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1016,7 +1018,7 @@ impl EditPredictionButton {
|
||||
})
|
||||
}
|
||||
|
||||
fn build_zeta_context_menu(
|
||||
fn build_edit_prediction_context_menu(
|
||||
&self,
|
||||
provider: EditPredictionProvider,
|
||||
window: &mut Window,
|
||||
@@ -1105,9 +1107,33 @@ impl EditPredictionButton {
|
||||
.separator();
|
||||
}
|
||||
|
||||
let menu = self.build_language_settings_menu(menu, window, cx);
|
||||
let menu = self.add_provider_switching_section(menu, provider, cx);
|
||||
menu = self.build_language_settings_menu(menu, window, cx);
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
let settings = all_language_settings(None, cx);
|
||||
let context_retrieval = settings.edit_predictions.use_context;
|
||||
menu = menu.separator().header("Context Retrieval").item(
|
||||
ContextMenuEntry::new("Enable Context Retrieval")
|
||||
.toggleable(IconPosition::Start, context_retrieval)
|
||||
.action(workspace::ToggleEditPrediction.boxed_clone())
|
||||
.handler({
|
||||
let fs = self.fs.clone();
|
||||
move |_, cx| {
|
||||
update_settings_file(fs.clone(), cx, move |settings, _| {
|
||||
settings
|
||||
.project
|
||||
.all_languages
|
||||
.features
|
||||
.get_or_insert_default()
|
||||
.experimental_edit_prediction_context_retrieval =
|
||||
Some(!context_retrieval)
|
||||
});
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
menu = self.add_provider_switching_section(menu, provider, cx);
|
||||
menu
|
||||
})
|
||||
}
|
||||
389
crates/edit_prediction_ui/src/edit_prediction_context_view.rs
Normal file
389
crates/edit_prediction_ui/src/edit_prediction_context_view.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::VecDeque,
|
||||
ops::Add,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use editor::{Editor, PathKey};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString,
|
||||
Styled as _, Task, TextAlign, Window, actions, div, pulsating_between,
|
||||
};
|
||||
use multi_buffer::MultiBuffer;
|
||||
use project::Project;
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{
|
||||
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
|
||||
StyledTypography as _, h_flex, v_flex,
|
||||
};
|
||||
|
||||
use edit_prediction::{
|
||||
ContextRetrievalFinishedDebugEvent, ContextRetrievalStartedDebugEvent, DebugEvent,
|
||||
EditPredictionStore,
|
||||
};
|
||||
use workspace::Item;
|
||||
|
||||
pub struct EditPredictionContextView {
|
||||
empty_focus_handle: FocusHandle,
|
||||
project: Entity<Project>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
runs: VecDeque<RetrievalRun>,
|
||||
current_ix: usize,
|
||||
_update_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RetrievalRun {
|
||||
editor: Entity<Editor>,
|
||||
started_at: Instant,
|
||||
metadata: Vec<(&'static str, SharedString)>,
|
||||
finished_at: Option<Instant>,
|
||||
}
|
||||
|
||||
actions!(
|
||||
dev,
|
||||
[
|
||||
/// Go to the previous context retrieval run
|
||||
EditPredictionContextGoBack,
|
||||
/// Go to the next context retrieval run
|
||||
EditPredictionContextGoForward
|
||||
]
|
||||
);
|
||||
|
||||
impl EditPredictionContextView {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
client: &Arc<Client>,
|
||||
user_store: &Entity<UserStore>,
|
||||
window: &mut gpui::Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let store = EditPredictionStore::global(client, user_store, cx);
|
||||
|
||||
let mut debug_rx = store.update(cx, |store, _| store.debug_info());
|
||||
let _update_task = cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.handle_store_event(event, window, cx)
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
|
||||
Self {
|
||||
empty_focus_handle: cx.focus_handle(),
|
||||
project,
|
||||
runs: VecDeque::new(),
|
||||
current_ix: 0,
|
||||
store,
|
||||
_update_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_store_event(
|
||||
&mut self,
|
||||
event: DebugEvent,
|
||||
window: &mut gpui::Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
DebugEvent::ContextRetrievalStarted(info) => {
|
||||
if info.project_entity_id == self.project.entity_id() {
|
||||
self.handle_context_retrieval_started(info, window, cx);
|
||||
}
|
||||
}
|
||||
DebugEvent::ContextRetrievalFinished(info) => {
|
||||
if info.project_entity_id == self.project.entity_id() {
|
||||
self.handle_context_retrieval_finished(info, window, cx);
|
||||
}
|
||||
}
|
||||
DebugEvent::EditPredictionRequested(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_context_retrieval_started(
|
||||
&mut self,
|
||||
info: ContextRetrievalStartedDebugEvent,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self
|
||||
.runs
|
||||
.back()
|
||||
.is_some_and(|run| run.finished_at.is_none())
|
||||
{
|
||||
self.runs.pop_back();
|
||||
}
|
||||
|
||||
let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly));
|
||||
let editor = cx
|
||||
.new(|cx| Editor::for_multibuffer(multibuffer, Some(self.project.clone()), window, cx));
|
||||
|
||||
if self.runs.len() == 32 {
|
||||
self.runs.pop_front();
|
||||
}
|
||||
|
||||
self.runs.push_back(RetrievalRun {
|
||||
editor,
|
||||
started_at: info.timestamp,
|
||||
finished_at: None,
|
||||
metadata: Vec::new(),
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_context_retrieval_finished(
|
||||
&mut self,
|
||||
info: ContextRetrievalFinishedDebugEvent,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(run) = self.runs.back_mut() else {
|
||||
return;
|
||||
};
|
||||
|
||||
run.finished_at = Some(info.timestamp);
|
||||
run.metadata = info.metadata;
|
||||
|
||||
let project = self.project.clone();
|
||||
let related_files = self
|
||||
.store
|
||||
.read(cx)
|
||||
.context_for_project(&self.project, cx)
|
||||
.to_vec();
|
||||
|
||||
let editor = run.editor.clone();
|
||||
let multibuffer = run.editor.read(cx).buffer().clone();
|
||||
|
||||
if self.current_ix + 2 == self.runs.len() {
|
||||
self.current_ix += 1;
|
||||
}
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let mut paths = Vec::new();
|
||||
for related_file in related_files {
|
||||
let (buffer, point_ranges): (_, Vec<_>) =
|
||||
if let Some(buffer) = related_file.buffer.upgrade() {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
|
||||
(
|
||||
buffer,
|
||||
related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(related_file.path.clone(), cx)
|
||||
})?
|
||||
.await?,
|
||||
related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.point_range.clone())
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
cx.update(|_, cx| {
|
||||
let path = PathKey::for_buffer(&buffer, cx);
|
||||
paths.push((path, buffer, point_ranges));
|
||||
})?;
|
||||
}
|
||||
|
||||
multibuffer.update(cx, |multibuffer, cx| {
|
||||
multibuffer.clear(cx);
|
||||
|
||||
for (path, buffer, ranges) in paths {
|
||||
multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
|
||||
}
|
||||
})?;
|
||||
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.move_to_beginning(&Default::default(), window, cx);
|
||||
})?;
|
||||
|
||||
this.update(cx, |_, cx| cx.notify())
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn handle_go_back(
|
||||
&mut self,
|
||||
_: &EditPredictionContextGoBack,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.current_ix = self.current_ix.saturating_sub(1);
|
||||
cx.focus_self(window);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_go_forward(
|
||||
&mut self,
|
||||
_: &EditPredictionContextGoForward,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.current_ix = self
|
||||
.current_ix
|
||||
.add(1)
|
||||
.min(self.runs.len().saturating_sub(1));
|
||||
cx.focus_self(window);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_informational_footer(
|
||||
&self,
|
||||
cx: &mut Context<'_, EditPredictionContextView>,
|
||||
) -> ui::Div {
|
||||
let run = &self.runs[self.current_ix];
|
||||
let new_run_started = self
|
||||
.runs
|
||||
.back()
|
||||
.map_or(false, |latest_run| latest_run.finished_at.is_none());
|
||||
|
||||
h_flex()
|
||||
.p_2()
|
||||
.w_full()
|
||||
.font_buffer(cx)
|
||||
.text_xs()
|
||||
.border_t_1()
|
||||
.gap_2()
|
||||
.child(v_flex().h_full().flex_1().child({
|
||||
let t0 = run.started_at;
|
||||
let mut table = ui::Table::<2>::new().width(ui::px(300.)).no_ui_font();
|
||||
for (key, value) in &run.metadata {
|
||||
table = table.row([key.into_any_element(), value.clone().into_any_element()])
|
||||
}
|
||||
table = table.row([
|
||||
"Total Time".into_any_element(),
|
||||
format!("{} ms", (run.finished_at.unwrap_or(t0) - t0).as_millis())
|
||||
.into_any_element(),
|
||||
]);
|
||||
table
|
||||
}))
|
||||
.child(
|
||||
v_flex().h_full().text_align(TextAlign::Right).child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.child(
|
||||
IconButton::new("go-back", IconName::ChevronLeft)
|
||||
.disabled(self.current_ix == 0 || self.runs.len() < 2)
|
||||
.tooltip(ui::Tooltip::for_action_title(
|
||||
"Go to previous run",
|
||||
&EditPredictionContextGoBack,
|
||||
))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.handle_go_back(&EditPredictionContextGoBack, window, cx);
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
|
||||
.map(|this| {
|
||||
if new_run_started {
|
||||
this.with_animation(
|
||||
"pulsating-count",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 0.8)),
|
||||
|label, delta| label.opacity(delta),
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
this.into_any_element()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("go-forward", IconName::ChevronRight)
|
||||
.disabled(self.current_ix + 1 == self.runs.len())
|
||||
.tooltip(ui::Tooltip::for_action_title(
|
||||
"Go to next run",
|
||||
&EditPredictionContextGoBack,
|
||||
))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.handle_go_forward(
|
||||
&EditPredictionContextGoForward,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for EditPredictionContextView {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.runs
|
||||
.get(self.current_ix)
|
||||
.map(|run| run.editor.read(cx).focus_handle(cx))
|
||||
.unwrap_or_else(|| self.empty_focus_handle.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<()> for EditPredictionContextView {}
|
||||
|
||||
impl Item for EditPredictionContextView {
|
||||
type Event = ();
|
||||
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"Edit Prediction Context".into()
|
||||
}
|
||||
|
||||
fn buffer_kind(&self, _cx: &App) -> workspace::item::ItemBufferKind {
|
||||
workspace::item::ItemBufferKind::Multibuffer
|
||||
}
|
||||
|
||||
fn act_as_type<'a>(
|
||||
&'a self,
|
||||
type_id: TypeId,
|
||||
self_handle: &'a Entity<Self>,
|
||||
_: &'a App,
|
||||
) -> Option<gpui::AnyEntity> {
|
||||
if type_id == TypeId::of::<Self>() {
|
||||
Some(self_handle.clone().into())
|
||||
} else if type_id == TypeId::of::<Editor>() {
|
||||
Some(self.runs.get(self.current_ix)?.editor.clone().into())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl gpui::Render for EditPredictionContextView {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
|
||||
v_flex()
|
||||
.key_context("EditPredictionContext")
|
||||
.on_action(cx.listener(Self::handle_go_back))
|
||||
.on_action(cx.listener(Self::handle_go_forward))
|
||||
.size_full()
|
||||
.map(|this| {
|
||||
if self.runs.is_empty() {
|
||||
this.child(
|
||||
v_flex()
|
||||
.size_full()
|
||||
.justify_center()
|
||||
.items_center()
|
||||
.child("No retrieval runs yet"),
|
||||
)
|
||||
} else {
|
||||
this.child(self.runs[self.current_ix].editor.clone())
|
||||
.child(self.render_informational_footer(cx))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
128
crates/edit_prediction_ui/src/edit_prediction_ui.rs
Normal file
128
crates/edit_prediction_ui/src/edit_prediction_ui.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
mod edit_prediction_button;
|
||||
mod edit_prediction_context_view;
|
||||
mod rate_prediction_modal;
|
||||
mod sweep_api_token_modal;
|
||||
|
||||
use std::any::{Any as _, TypeId};
|
||||
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag};
|
||||
use edit_prediction_context_view::EditPredictionContextView;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use gpui::actions;
|
||||
use project::DisableAiSettings;
|
||||
use rate_prediction_modal::RatePredictionsModal;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use ui::{App, prelude::*};
|
||||
use workspace::{SplitDirection, Workspace};
|
||||
|
||||
pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
|
||||
pub use sweep_api_token_modal::SweepApiKeyModal;
|
||||
|
||||
use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
|
||||
|
||||
actions!(
|
||||
dev,
|
||||
[
|
||||
/// Opens the edit prediction context view.
|
||||
OpenEditPredictionContextView,
|
||||
]
|
||||
);
|
||||
|
||||
actions!(
|
||||
edit_prediction,
|
||||
[
|
||||
/// Opens the rate completions modal.
|
||||
RatePredictions,
|
||||
]
|
||||
);
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
feature_gate_predict_edits_actions(cx);
|
||||
|
||||
cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
|
||||
workspace.register_action(|workspace, _: &RatePredictions, window, cx| {
|
||||
if cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>() {
|
||||
RatePredictionsModal::toggle(workspace, window, cx);
|
||||
}
|
||||
});
|
||||
|
||||
workspace.register_action_renderer(|div, _, _, cx| {
|
||||
let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
|
||||
div.when(has_flag, |div| {
|
||||
div.on_action(cx.listener(
|
||||
move |workspace, _: &OpenEditPredictionContextView, window, cx| {
|
||||
let project = workspace.project();
|
||||
workspace.split_item(
|
||||
SplitDirection::Right,
|
||||
Box::new(cx.new(|cx| {
|
||||
EditPredictionContextView::new(
|
||||
project.clone(),
|
||||
workspace.client(),
|
||||
workspace.user_store(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
},
|
||||
))
|
||||
})
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn feature_gate_predict_edits_actions(cx: &mut App) {
|
||||
let rate_completion_action_types = [TypeId::of::<RatePredictions>()];
|
||||
let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
|
||||
let all_action_types = [
|
||||
TypeId::of::<RatePredictions>(),
|
||||
TypeId::of::<edit_prediction::ResetOnboarding>(),
|
||||
zed_actions::OpenZedPredictOnboarding.type_id(),
|
||||
TypeId::of::<edit_prediction::ClearHistory>(),
|
||||
TypeId::of::<rate_prediction_modal::ThumbsUpActivePrediction>(),
|
||||
TypeId::of::<rate_prediction_modal::ThumbsDownActivePrediction>(),
|
||||
TypeId::of::<rate_prediction_modal::NextEdit>(),
|
||||
TypeId::of::<rate_prediction_modal::PreviousEdit>(),
|
||||
];
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_action_types(&rate_completion_action_types);
|
||||
filter.hide_action_types(&reset_onboarding_action_types);
|
||||
filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
|
||||
});
|
||||
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
|
||||
let has_feature_flag = cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>();
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
if is_ai_disabled {
|
||||
filter.hide_action_types(&all_action_types);
|
||||
} else if has_feature_flag {
|
||||
filter.show_action_types(&rate_completion_action_types);
|
||||
} else {
|
||||
filter.hide_action_types(&rate_completion_action_types);
|
||||
}
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.observe_flag::<PredictEditsRatePredictionsFeatureFlag, _>(move |is_enabled, cx| {
|
||||
if !DisableAiSettings::get_global(cx).disable_ai {
|
||||
if is_enabled {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.show_action_types(&rate_completion_action_types);
|
||||
});
|
||||
} else {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_action_types(&rate_completion_action_types);
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::{EditPrediction, EditPredictionRating, Zeta};
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use cloud_zeta2_prompt::write_codeblock;
|
||||
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
|
||||
use editor::{Editor, ExcerptRange, MultiBuffer};
|
||||
use feature_flags::FeatureFlag;
|
||||
use gpui::{
|
||||
App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
|
||||
Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*,
|
||||
@@ -9,9 +10,7 @@ use gpui::{
|
||||
use language::{LanguageRegistry, Point, language_settings};
|
||||
use markdown::{Markdown, MarkdownStyle};
|
||||
use settings::Settings as _;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt::Write, sync::Arc, time::Duration};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*};
|
||||
use workspace::{ModalView, Workspace};
|
||||
@@ -34,8 +33,14 @@ actions!(
|
||||
]
|
||||
);
|
||||
|
||||
pub struct PredictEditsRatePredictionsFeatureFlag;
|
||||
|
||||
impl FeatureFlag for PredictEditsRatePredictionsFeatureFlag {
|
||||
const NAME: &'static str = "predict-edits-rate-completions";
|
||||
}
|
||||
|
||||
pub struct RatePredictionsModal {
|
||||
zeta: Entity<Zeta>,
|
||||
ep_store: Entity<EditPredictionStore>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
active_prediction: Option<ActivePrediction>,
|
||||
selected_index: usize,
|
||||
@@ -68,10 +73,10 @@ impl RatePredictionView {
|
||||
|
||||
impl RatePredictionsModal {
|
||||
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
|
||||
if let Some(zeta) = Zeta::try_global(cx) {
|
||||
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
|
||||
let language_registry = workspace.app_state().languages.clone();
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
RatePredictionsModal::new(zeta, language_registry, window, cx)
|
||||
RatePredictionsModal::new(ep_store, language_registry, window, cx)
|
||||
});
|
||||
|
||||
telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction");
|
||||
@@ -79,15 +84,15 @@ impl RatePredictionsModal {
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
zeta: Entity<Zeta>,
|
||||
ep_store: Entity<EditPredictionStore>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let subscription = cx.observe(&zeta, |_, _, cx| cx.notify());
|
||||
let subscription = cx.observe(&ep_store, |_, _, cx| cx.notify());
|
||||
|
||||
Self {
|
||||
zeta,
|
||||
ep_store,
|
||||
language_registry,
|
||||
selected_index: 0,
|
||||
focus_handle: cx.focus_handle(),
|
||||
@@ -113,7 +118,7 @@ impl RatePredictionsModal {
|
||||
self.selected_index += 1;
|
||||
self.selected_index = usize::min(
|
||||
self.selected_index,
|
||||
self.zeta.read(cx).shown_predictions().count(),
|
||||
self.ep_store.read(cx).shown_predictions().count(),
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
@@ -130,7 +135,7 @@ impl RatePredictionsModal {
|
||||
|
||||
fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let next_index = self
|
||||
.zeta
|
||||
.ep_store
|
||||
.read(cx)
|
||||
.shown_predictions()
|
||||
.skip(self.selected_index)
|
||||
@@ -146,11 +151,11 @@ impl RatePredictionsModal {
|
||||
}
|
||||
|
||||
fn select_prev_edit(&mut self, _: &PreviousEdit, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let zeta = self.zeta.read(cx);
|
||||
let completions_len = zeta.shown_completions_len();
|
||||
let ep_store = self.ep_store.read(cx);
|
||||
let completions_len = ep_store.shown_completions_len();
|
||||
|
||||
let prev_index = self
|
||||
.zeta
|
||||
.ep_store
|
||||
.read(cx)
|
||||
.shown_predictions()
|
||||
.rev()
|
||||
@@ -173,7 +178,7 @@ impl RatePredictionsModal {
|
||||
}
|
||||
|
||||
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.selected_index = self.zeta.read(cx).shown_completions_len() - 1;
|
||||
self.selected_index = self.ep_store.read(cx).shown_completions_len() - 1;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -183,9 +188,9 @@ impl RatePredictionsModal {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
self.ep_store.update(cx, |ep_store, cx| {
|
||||
if let Some(active) = &self.active_prediction {
|
||||
zeta.rate_prediction(
|
||||
ep_store.rate_prediction(
|
||||
&active.prediction,
|
||||
EditPredictionRating::Positive,
|
||||
active.feedback_editor.read(cx).text(cx),
|
||||
@@ -216,8 +221,8 @@ impl RatePredictionsModal {
|
||||
return;
|
||||
}
|
||||
|
||||
self.zeta.update(cx, |zeta, cx| {
|
||||
zeta.rate_prediction(
|
||||
self.ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.rate_prediction(
|
||||
&active.prediction,
|
||||
EditPredictionRating::Negative,
|
||||
active.feedback_editor.read(cx).text(cx),
|
||||
@@ -254,7 +259,7 @@ impl RatePredictionsModal {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let completion = self
|
||||
.zeta
|
||||
.ep_store
|
||||
.read(cx)
|
||||
.shown_predictions()
|
||||
.skip(self.selected_index)
|
||||
@@ -267,7 +272,7 @@ impl RatePredictionsModal {
|
||||
|
||||
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let completion = self
|
||||
.zeta
|
||||
.ep_store
|
||||
.read(cx)
|
||||
.shown_predictions()
|
||||
.skip(self.selected_index)
|
||||
@@ -288,7 +293,7 @@ impl RatePredictionsModal {
|
||||
// Avoid resetting completion rating if it's already selected.
|
||||
if let Some(prediction) = prediction {
|
||||
self.selected_index = self
|
||||
.zeta
|
||||
.ep_store
|
||||
.read(cx)
|
||||
.shown_predictions()
|
||||
.enumerate()
|
||||
@@ -376,7 +381,7 @@ impl RatePredictionsModal {
|
||||
&included_file.path,
|
||||
&included_file.excerpts,
|
||||
if included_file.path == prediction.inputs.cursor_path {
|
||||
cursor_insertions
|
||||
cursor_insertions.as_slice()
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
@@ -564,7 +569,7 @@ impl RatePredictionsModal {
|
||||
let border_color = cx.theme().colors().border;
|
||||
let bg_color = cx.theme().colors().editor_background;
|
||||
|
||||
let rated = self.zeta.read(cx).is_prediction_rated(&completion_id);
|
||||
let rated = self.ep_store.read(cx).is_prediction_rated(&completion_id);
|
||||
let feedback_empty = active_prediction
|
||||
.feedback_editor
|
||||
.read(cx)
|
||||
@@ -715,7 +720,7 @@ impl RatePredictionsModal {
|
||||
}
|
||||
|
||||
fn render_shown_completions(&self, cx: &Context<Self>) -> impl Iterator<Item = ListItem> {
|
||||
self.zeta
|
||||
self.ep_store
|
||||
.read(cx)
|
||||
.shown_predictions()
|
||||
.cloned()
|
||||
@@ -725,7 +730,7 @@ impl RatePredictionsModal {
|
||||
.active_prediction
|
||||
.as_ref()
|
||||
.is_some_and(|selected| selected.prediction.id == completion.id);
|
||||
let rated = self.zeta.read(cx).is_prediction_rated(&completion.id);
|
||||
let rated = self.ep_store.read(cx).is_prediction_rated(&completion.id);
|
||||
|
||||
let (icon_name, icon_color, tooltip_text) =
|
||||
match (rated, completion.edits.is_empty()) {
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user