Compare commits
53 Commits
fix-max-im
...
test-drive
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c240f876b1 | ||
|
|
b076ff99ef | ||
|
|
36c173e3e2 | ||
|
|
6e9c6c5684 | ||
|
|
42f788185a | ||
|
|
a5b2428897 | ||
|
|
0629804390 | ||
|
|
3151b5efc1 | ||
|
|
782fbfad90 | ||
|
|
2caa19214b | ||
|
|
bff5d85ff4 | ||
|
|
abe5d523e1 | ||
|
|
8fb3199a84 | ||
|
|
0d809c21ba | ||
|
|
93b1e95a5d | ||
|
|
49bc2e61da | ||
|
|
9a4bcd11a2 | ||
|
|
2ee5bedfa9 | ||
|
|
d497f52e17 | ||
|
|
f022a13091 | ||
|
|
c74ecb4654 | ||
|
|
7609ca7a8d | ||
|
|
32906bfa7c | ||
|
|
5fafab6e52 | ||
|
|
8573b3a84b | ||
|
|
5e70235794 | ||
|
|
9e9192f6a3 | ||
|
|
936972d9b0 | ||
|
|
e9533423db | ||
|
|
ba480295c1 | ||
|
|
9106f4495b | ||
|
|
1feb1296fe | ||
|
|
582a247922 | ||
|
|
c2881a4537 | ||
|
|
b4744750da | ||
|
|
6edc255158 | ||
|
|
a96a1b1339 | ||
|
|
73cee468ed | ||
|
|
1f06615da2 | ||
|
|
c1773f7281 | ||
|
|
6c6b1ba3bc | ||
|
|
a23d9328ce | ||
|
|
5796a2663b | ||
|
|
447eb8e1c9 | ||
|
|
e434117018 | ||
|
|
36271b79b3 | ||
|
|
41644a53cc | ||
|
|
08a9c4af09 | ||
|
|
3187f28405 | ||
|
|
101f3b100f | ||
|
|
39c8b7bf5f | ||
|
|
08b41252f6 | ||
|
|
152bbca238 |
63
Cargo.lock
generated
63
Cargo.lock
generated
@@ -107,6 +107,39 @@ dependencies = [
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"thiserror 2.0.12",
|
||||
"util",
|
||||
"worktree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_settings"
|
||||
version = "0.1.0"
|
||||
@@ -1911,7 +1944,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"strum 0.27.1",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -4133,7 +4165,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "dap-types"
|
||||
version = "0.0.1"
|
||||
source = "git+https://github.com/zed-industries/dap-types?rev=b40956a7f4d1939da67429d941389ee306a3a308#b40956a7f4d1939da67429d941389ee306a3a308"
|
||||
source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9"
|
||||
dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
@@ -4814,6 +4846,7 @@ dependencies = [
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"rpc",
|
||||
"schemars",
|
||||
@@ -8847,6 +8880,7 @@ dependencies = [
|
||||
"http_client",
|
||||
"imara-diff",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"lsp",
|
||||
@@ -8945,8 +8979,10 @@ dependencies = [
|
||||
"aws-credential-types",
|
||||
"aws_http_client",
|
||||
"bedrock",
|
||||
"chrono",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
"copilot",
|
||||
"credentials_provider",
|
||||
"deepseek",
|
||||
@@ -14053,12 +14089,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "schemars"
|
||||
version = "0.8.22"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
|
||||
checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984"
|
||||
dependencies = [
|
||||
"dyn-clone",
|
||||
"indexmap",
|
||||
"ref-cast",
|
||||
"schemars_derive",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -14066,9 +14103,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "schemars_derive"
|
||||
version = "0.8.22"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
|
||||
checksum = "6ca9fcb757952f8e8629b9ab066fc62da523c46c2b247b1708a3be06dd82530b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -14567,13 +14604,22 @@ dependencies = [
|
||||
name = "settings_ui"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"collections",
|
||||
"command_palette",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
"db",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"log",
|
||||
"menu",
|
||||
"paths",
|
||||
"project",
|
||||
"schemars",
|
||||
"search",
|
||||
"serde",
|
||||
"settings",
|
||||
"theme",
|
||||
@@ -16010,6 +16056,7 @@ dependencies = [
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indexmap",
|
||||
"inventory",
|
||||
"log",
|
||||
"palette",
|
||||
"parking_lot",
|
||||
@@ -20127,9 +20174,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.8.4"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
|
||||
checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
|
||||
@@ -4,6 +4,7 @@ members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/agent_ui",
|
||||
"crates/agent",
|
||||
"crates/agent2",
|
||||
"crates/agent_settings",
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
@@ -444,7 +445,7 @@ core-video = { version = "0.4.3", features = ["metal"] }
|
||||
cpal = "0.16"
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
ctor = "0.4.0"
|
||||
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "b40956a7f4d1939da67429d941389ee306a3a308" }
|
||||
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" }
|
||||
dashmap = "6.0"
|
||||
derive_more = "0.99.17"
|
||||
dirs = "4.0"
|
||||
@@ -540,7 +541,7 @@ rustc-hash = "2.1.0"
|
||||
rustls = { version = "0.23.26" }
|
||||
rustls-platform-verifier = "0.5.0"
|
||||
scap = { git = "https://github.com/zed-industries/scap", rev = "08f0a01417505cc0990b9931a37e5120db92e0d0", default-features = false }
|
||||
schemars = { version = "0.8", features = ["impl_json_schema", "indexmap2"] }
|
||||
schemars = { version = "1.0", features = ["indexmap2"] }
|
||||
semver = "1.0"
|
||||
serde = { version = "1.0", features = ["derive", "rc"] }
|
||||
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
|
||||
@@ -625,7 +626,7 @@ wasmtime = { version = "29", default-features = false, features = [
|
||||
wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.8.4"
|
||||
zed_llm_client = "0.8.5"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
|
||||
@@ -1067,5 +1067,12 @@
|
||||
"ctrl-tab": "pane::ActivateNextItem",
|
||||
"ctrl-shift-tab": "pane::ActivatePreviousItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "KeymapEditor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-f": "search::FocusSearch"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1167,5 +1167,12 @@
|
||||
"ctrl-tab": "pane::ActivateNextItem",
|
||||
"ctrl-shift-tab": "pane::ActivatePreviousItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "KeymapEditor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-f": "search::FocusSearch"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -111,7 +111,7 @@ mod tests {
|
||||
use assistant_tool::ToolRegistry;
|
||||
use collections::IndexMap;
|
||||
use gpui::SharedString;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use gpui::TestAppContext;
|
||||
use http_client::FakeHttpClient;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
||||
@@ -819,134 +819,6 @@ impl LoadedContext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_to_request_message_with_model(
|
||||
&self,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
model: &Arc<dyn language_model::LanguageModel>,
|
||||
) {
|
||||
if !self.text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(self.text.to_string()));
|
||||
}
|
||||
|
||||
if !self.images.is_empty() {
|
||||
let max_image_size = model.max_image_size();
|
||||
let mut images_added = false;
|
||||
|
||||
for image in &self.images {
|
||||
let image_size = image.len() as u64;
|
||||
if image_size > max_image_size {
|
||||
if max_image_size == 0 {
|
||||
log::warn!(
|
||||
"Skipping image attachment: model {:?} does not support images",
|
||||
model.name()
|
||||
);
|
||||
} else {
|
||||
log::warn!(
|
||||
"Skipping image attachment: size {} bytes exceeds model {:?} limit of {} bytes",
|
||||
image_size,
|
||||
model.name(),
|
||||
max_image_size
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Some providers only support image parts after an initial text part
|
||||
if !images_added && request_message.content.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text("Images attached by user:".to_string()));
|
||||
}
|
||||
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Image(image.clone()));
|
||||
images_added = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks images against model size limits and returns information about rejected images
|
||||
pub fn check_image_size_limits(
|
||||
&self,
|
||||
model: &Arc<dyn language_model::LanguageModel>,
|
||||
) -> Vec<RejectedImage> {
|
||||
let mut rejected_images = Vec::new();
|
||||
|
||||
if !self.images.is_empty() {
|
||||
let max_image_size = model.max_image_size();
|
||||
|
||||
for image in &self.images {
|
||||
let image_size = image.len() as u64;
|
||||
if image_size > max_image_size {
|
||||
rejected_images.push(RejectedImage {
|
||||
size: image_size,
|
||||
max_size: max_image_size,
|
||||
model_name: model.name().0.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rejected_images
|
||||
}
|
||||
|
||||
pub fn add_to_request_message_with_validation<F>(
|
||||
&self,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
model: &Arc<dyn language_model::LanguageModel>,
|
||||
mut on_image_rejected: F,
|
||||
) where
|
||||
F: FnMut(u64, u64, &str),
|
||||
{
|
||||
if !self.text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(self.text.to_string()));
|
||||
}
|
||||
|
||||
if !self.images.is_empty() {
|
||||
let max_image_size = model.max_image_size();
|
||||
let mut images_added = false;
|
||||
|
||||
for image in &self.images {
|
||||
let image_size = image.len() as u64;
|
||||
if image_size > max_image_size {
|
||||
on_image_rejected(image_size, max_image_size, &model.name().0);
|
||||
|
||||
if max_image_size == 0 {
|
||||
log::warn!(
|
||||
"Skipping image attachment: model {:?} does not support images",
|
||||
model.name()
|
||||
);
|
||||
} else {
|
||||
log::warn!(
|
||||
"Skipping image attachment: size {} bytes exceeds model {:?} limit of {} bytes",
|
||||
image_size,
|
||||
model.name(),
|
||||
max_image_size
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Some providers only support image parts after an initial text part
|
||||
if !images_added && request_message.content.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text("Images attached by user:".to_string()));
|
||||
}
|
||||
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Image(image.clone()));
|
||||
images_added = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads and formats a collection of contexts.
|
||||
@@ -1240,18 +1112,10 @@ impl Hash for AgentContextKey {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RejectedImage {
|
||||
pub size: u64,
|
||||
pub max_size: u64,
|
||||
pub model_name: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::{AsyncApp, TestAppContext};
|
||||
use language_model::{LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName};
|
||||
use gpui::TestAppContext;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
@@ -1358,484 +1222,4 @@ mod tests {
|
||||
})
|
||||
.expect("Should have found a file context")
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_image_size_limit_filtering(_cx: &mut TestAppContext) {
|
||||
use futures::stream::BoxStream;
|
||||
use gpui::{AsyncApp, DevicePixels, SharedString};
|
||||
use language_model::{
|
||||
LanguageModelId, LanguageModelImage, LanguageModelName, LanguageModelProviderId,
|
||||
LanguageModelProviderName, Role,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
// Create a mock image that's 10 bytes
|
||||
let small_image = LanguageModelImage {
|
||||
source: "small_data".into(),
|
||||
size: gpui::size(DevicePixels(10), DevicePixels(10)),
|
||||
};
|
||||
|
||||
// Create a mock image that's 1MB
|
||||
let large_image_source = "x".repeat(1_048_576);
|
||||
let large_image = LanguageModelImage {
|
||||
source: large_image_source.into(),
|
||||
size: gpui::size(DevicePixels(1024), DevicePixels(1024)),
|
||||
};
|
||||
|
||||
let loaded_context = LoadedContext {
|
||||
contexts: vec![],
|
||||
text: "Some text".to_string(),
|
||||
images: vec![small_image.clone(), large_image.clone()],
|
||||
};
|
||||
|
||||
// Test with a model that supports images with 500KB limit
|
||||
struct TestModel500KB;
|
||||
impl language_model::LanguageModel for TestModel500KB {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(SharedString::from("test-500kb"))
|
||||
}
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(SharedString::from("Test Model 500KB"))
|
||||
}
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(SharedString::from("test"))
|
||||
}
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(SharedString::from("Test Provider"))
|
||||
}
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn supports_tool_choice(&self, _: language_model::LanguageModelToolChoice) -> bool {
|
||||
false
|
||||
}
|
||||
fn max_image_size(&self) -> u64 {
|
||||
512_000
|
||||
} // 500KB
|
||||
fn telemetry_id(&self) -> String {
|
||||
"test-500kb".to_string()
|
||||
}
|
||||
fn max_token_count(&self) -> u64 {
|
||||
100_000
|
||||
}
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<u64>> {
|
||||
Box::pin(async { Ok(0) })
|
||||
}
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &AsyncApp,
|
||||
) -> futures::future::BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<
|
||||
language_model::LanguageModelCompletionEvent,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
use language_model::LanguageModelCompletionError;
|
||||
Box::pin(async {
|
||||
Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
|
||||
"Not implemented"
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let model_500kb: Arc<dyn language_model::LanguageModel> = Arc::new(TestModel500KB);
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
loaded_context.add_to_request_message_with_model(&mut request_message, &model_500kb);
|
||||
|
||||
// Should have text and only the small image
|
||||
assert_eq!(request_message.content.len(), 2); // text + small image
|
||||
assert!(
|
||||
matches!(&request_message.content[0], MessageContent::Text(text) if text == "Some text")
|
||||
);
|
||||
assert!(matches!(
|
||||
&request_message.content[1],
|
||||
MessageContent::Image(_)
|
||||
));
|
||||
|
||||
// Test with a model that doesn't support images
|
||||
struct TestModelNoImages;
|
||||
impl language_model::LanguageModel for TestModelNoImages {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(SharedString::from("test-no-images"))
|
||||
}
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(SharedString::from("Test Model No Images"))
|
||||
}
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(SharedString::from("test"))
|
||||
}
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(SharedString::from("Test Provider"))
|
||||
}
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn supports_tool_choice(&self, _: language_model::LanguageModelToolChoice) -> bool {
|
||||
false
|
||||
}
|
||||
fn max_image_size(&self) -> u64 {
|
||||
0
|
||||
} // No image support
|
||||
fn telemetry_id(&self) -> String {
|
||||
"test-no-images".to_string()
|
||||
}
|
||||
fn max_token_count(&self) -> u64 {
|
||||
100_000
|
||||
}
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<u64>> {
|
||||
Box::pin(async { Ok(0) })
|
||||
}
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &AsyncApp,
|
||||
) -> futures::future::BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<
|
||||
language_model::LanguageModelCompletionEvent,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
use language_model::LanguageModelCompletionError;
|
||||
Box::pin(async {
|
||||
Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
|
||||
"Not implemented"
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let model_no_images: Arc<dyn language_model::LanguageModel> = Arc::new(TestModelNoImages);
|
||||
let mut request_message_no_images = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
loaded_context
|
||||
.add_to_request_message_with_model(&mut request_message_no_images, &model_no_images);
|
||||
|
||||
// Should have only text, no images
|
||||
assert_eq!(request_message_no_images.content.len(), 1);
|
||||
assert!(
|
||||
matches!(&request_message_no_images.content[0], MessageContent::Text(text) if text == "Some text")
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_check_image_size_limits() {
|
||||
use gpui::DevicePixels;
|
||||
use language_model::LanguageModelImage;
|
||||
|
||||
// Create test images of various sizes
|
||||
let tiny_image = LanguageModelImage {
|
||||
source: "tiny".into(),
|
||||
size: gpui::size(DevicePixels(10), DevicePixels(10)),
|
||||
};
|
||||
|
||||
let small_image = LanguageModelImage {
|
||||
source: "x".repeat(100_000).into(), // 100KB
|
||||
size: gpui::size(DevicePixels(100), DevicePixels(100)),
|
||||
};
|
||||
|
||||
let medium_image = LanguageModelImage {
|
||||
source: "x".repeat(500_000).into(), // 500KB
|
||||
size: gpui::size(DevicePixels(500), DevicePixels(500)),
|
||||
};
|
||||
|
||||
let large_image = LanguageModelImage {
|
||||
source: "x".repeat(1_048_576).into(), // 1MB
|
||||
size: gpui::size(DevicePixels(1024), DevicePixels(1024)),
|
||||
};
|
||||
|
||||
let huge_image = LanguageModelImage {
|
||||
source: "x".repeat(5_242_880).into(), // 5MB
|
||||
size: gpui::size(DevicePixels(2048), DevicePixels(2048)),
|
||||
};
|
||||
|
||||
// Test with model that has 1MB limit
|
||||
let model_1mb = Arc::new(TestModel1MB);
|
||||
let loaded_context = LoadedContext {
|
||||
contexts: vec![],
|
||||
text: String::new(),
|
||||
images: vec![
|
||||
tiny_image.clone(),
|
||||
small_image.clone(),
|
||||
medium_image.clone(),
|
||||
large_image.clone(),
|
||||
huge_image.clone(),
|
||||
],
|
||||
};
|
||||
|
||||
let rejected = loaded_context.check_image_size_limits(
|
||||
&(model_1mb.clone() as Arc<dyn language_model::LanguageModel>),
|
||||
);
|
||||
assert_eq!(rejected.len(), 1);
|
||||
assert_eq!(rejected[0].size, 5_242_880);
|
||||
assert_eq!(rejected[0].max_size, 1_048_576);
|
||||
assert_eq!(rejected[0].model_name, "Test Model 1MB");
|
||||
|
||||
// Test with model that doesn't support images
|
||||
let model_no_images = Arc::new(TestModelNoImages);
|
||||
let rejected = loaded_context.check_image_size_limits(
|
||||
&(model_no_images.clone() as Arc<dyn language_model::LanguageModel>),
|
||||
);
|
||||
assert_eq!(rejected.len(), 5); // All images rejected
|
||||
for (_i, rejected_image) in rejected.iter().enumerate() {
|
||||
assert_eq!(rejected_image.max_size, 0);
|
||||
assert_eq!(rejected_image.model_name, "Test Model No Images");
|
||||
}
|
||||
|
||||
// Test with empty image list
|
||||
let empty_context = LoadedContext {
|
||||
contexts: vec![],
|
||||
text: String::new(),
|
||||
images: vec![],
|
||||
};
|
||||
let rejected = empty_context.check_image_size_limits(
|
||||
&(model_1mb.clone() as Arc<dyn language_model::LanguageModel>),
|
||||
);
|
||||
assert!(rejected.is_empty());
|
||||
|
||||
// Test with all images within limit
|
||||
let small_context = LoadedContext {
|
||||
contexts: vec![],
|
||||
text: String::new(),
|
||||
images: vec![tiny_image.clone(), small_image.clone()],
|
||||
};
|
||||
let rejected = small_context
|
||||
.check_image_size_limits(&(model_1mb as Arc<dyn language_model::LanguageModel>));
|
||||
assert!(rejected.is_empty());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_add_to_request_message_with_validation() {
|
||||
use gpui::DevicePixels;
|
||||
use language_model::{LanguageModelImage, MessageContent, Role};
|
||||
|
||||
let small_image = LanguageModelImage {
|
||||
source: "small".into(),
|
||||
size: gpui::size(DevicePixels(10), DevicePixels(10)),
|
||||
};
|
||||
|
||||
let large_image = LanguageModelImage {
|
||||
source: "x".repeat(2_097_152).into(), // 2MB
|
||||
size: gpui::size(DevicePixels(1024), DevicePixels(1024)),
|
||||
};
|
||||
|
||||
let loaded_context = LoadedContext {
|
||||
contexts: vec![],
|
||||
text: "Test message".to_string(),
|
||||
images: vec![small_image.clone(), large_image.clone()],
|
||||
};
|
||||
|
||||
let model = Arc::new(TestModel1MB);
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
let mut rejected_count = 0;
|
||||
let mut rejected_sizes = Vec::new();
|
||||
let mut rejected_model_names = Vec::new();
|
||||
|
||||
loaded_context.add_to_request_message_with_validation(
|
||||
&mut request_message,
|
||||
&(model.clone() as Arc<dyn language_model::LanguageModel>),
|
||||
|size, max_size, model_name| {
|
||||
rejected_count += 1;
|
||||
rejected_sizes.push((size, max_size));
|
||||
rejected_model_names.push(model_name.to_string());
|
||||
},
|
||||
);
|
||||
|
||||
// Verify callback was called for the large image
|
||||
assert_eq!(rejected_count, 1);
|
||||
assert_eq!(rejected_sizes[0], (2_097_152, 1_048_576));
|
||||
assert_eq!(rejected_model_names[0], "Test Model 1MB");
|
||||
|
||||
// Verify the request message contains text and only the small image
|
||||
assert_eq!(request_message.content.len(), 2); // text + small image
|
||||
assert!(
|
||||
matches!(&request_message.content[0], MessageContent::Text(text) if text == "Test message")
|
||||
);
|
||||
assert!(matches!(
|
||||
&request_message.content[1],
|
||||
MessageContent::Image(_)
|
||||
));
|
||||
}
|
||||
|
||||
// Helper test models
|
||||
struct TestModel1MB;
|
||||
impl language_model::LanguageModel for TestModel1MB {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(SharedString::from("test-1mb"))
|
||||
}
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(SharedString::from("Test Model 1MB"))
|
||||
}
|
||||
fn provider_id(&self) -> language_model::LanguageModelProviderId {
|
||||
language_model::LanguageModelProviderId(SharedString::from("test"))
|
||||
}
|
||||
fn provider_name(&self) -> language_model::LanguageModelProviderName {
|
||||
language_model::LanguageModelProviderName(SharedString::from("Test Provider"))
|
||||
}
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn supports_tool_choice(&self, _: language_model::LanguageModelToolChoice) -> bool {
|
||||
false
|
||||
}
|
||||
fn max_image_size(&self) -> u64 {
|
||||
1_048_576 // 1MB
|
||||
}
|
||||
fn telemetry_id(&self) -> String {
|
||||
"test-1mb".to_string()
|
||||
}
|
||||
fn max_token_count(&self) -> u64 {
|
||||
100_000
|
||||
}
|
||||
fn max_output_tokens(&self) -> Option<u64> {
|
||||
Some(4096)
|
||||
}
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
Some(LanguageModelCacheConfiguration {
|
||||
max_cache_anchors: 0,
|
||||
should_speculate: false,
|
||||
min_total_token: 1024,
|
||||
})
|
||||
}
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<u64>> {
|
||||
Box::pin(async { Ok(0) })
|
||||
}
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &AsyncApp,
|
||||
) -> futures::future::BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<
|
||||
language_model::LanguageModelCompletionEvent,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
use language_model::LanguageModelCompletionError;
|
||||
Box::pin(async {
|
||||
Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
|
||||
"Not implemented"
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct TestModelNoImages;
|
||||
impl language_model::LanguageModel for TestModelNoImages {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(SharedString::from("test-no-images"))
|
||||
}
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(SharedString::from("Test Model No Images"))
|
||||
}
|
||||
fn provider_id(&self) -> language_model::LanguageModelProviderId {
|
||||
language_model::LanguageModelProviderId(SharedString::from("test"))
|
||||
}
|
||||
fn provider_name(&self) -> language_model::LanguageModelProviderName {
|
||||
language_model::LanguageModelProviderName(SharedString::from("Test Provider"))
|
||||
}
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn supports_tool_choice(&self, _: language_model::LanguageModelToolChoice) -> bool {
|
||||
false
|
||||
}
|
||||
fn max_image_size(&self) -> u64 {
|
||||
0 // No image support
|
||||
}
|
||||
fn telemetry_id(&self) -> String {
|
||||
"test-no-images".to_string()
|
||||
}
|
||||
fn max_token_count(&self) -> u64 {
|
||||
100_000
|
||||
}
|
||||
fn max_output_tokens(&self) -> Option<u64> {
|
||||
Some(4096)
|
||||
}
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
Some(LanguageModelCacheConfiguration {
|
||||
max_cache_anchors: 0,
|
||||
should_speculate: false,
|
||||
min_total_token: 1024,
|
||||
})
|
||||
}
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<u64>> {
|
||||
Box::pin(async { Ok(0) })
|
||||
}
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &AsyncApp,
|
||||
) -> futures::future::BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<
|
||||
language_model::LanguageModelCompletionEvent,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
use language_model::LanguageModelCompletionError;
|
||||
Box::pin(async {
|
||||
Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
|
||||
"Not implemented"
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,11 +23,10 @@ use gpui::{
|
||||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
|
||||
TokenUsage,
|
||||
LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
|
||||
Role, SelectedModel, StopReason, TokenUsage,
|
||||
};
|
||||
use postage::stream::Stream as _;
|
||||
use project::{
|
||||
@@ -1338,7 +1337,7 @@ impl Thread {
|
||||
|
||||
message
|
||||
.loaded_context
|
||||
.add_to_request_message_with_model(&mut request_message, &model);
|
||||
.add_to_request_message(&mut request_message);
|
||||
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
@@ -1531,82 +1530,7 @@ impl Thread {
|
||||
}
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(error) => {
|
||||
match error {
|
||||
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
|
||||
anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
|
||||
}
|
||||
LanguageModelCompletionError::Overloaded => {
|
||||
anyhow::bail!(LanguageModelKnownError::Overloaded);
|
||||
}
|
||||
LanguageModelCompletionError::ApiInternalServerError =>{
|
||||
anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
|
||||
}
|
||||
LanguageModelCompletionError::PromptTooLarge { tokens } => {
|
||||
let tokens = tokens.unwrap_or_else(|| {
|
||||
// We didn't get an exact token count from the API, so fall back on our estimate.
|
||||
thread.total_token_usage()
|
||||
.map(|usage| usage.total)
|
||||
.unwrap_or(0)
|
||||
// We know the context window was exceeded in practice, so if our estimate was
|
||||
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
|
||||
.max(model.max_token_count().saturating_add(1))
|
||||
});
|
||||
|
||||
anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
|
||||
}
|
||||
LanguageModelCompletionError::ApiReadResponseError(io_error) => {
|
||||
anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
|
||||
}
|
||||
LanguageModelCompletionError::UnknownResponseFormat(error) => {
|
||||
anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
|
||||
}
|
||||
LanguageModelCompletionError::HttpResponseError { status, ref body } => {
|
||||
if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
|
||||
anyhow::bail!(known_error);
|
||||
} else {
|
||||
return Err(error.into());
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionError::DeserializeResponse(error) => {
|
||||
anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
|
||||
}
|
||||
LanguageModelCompletionError::BadInputJson {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input: invalid_input_json,
|
||||
json_parse_error,
|
||||
} => {
|
||||
thread.receive_invalid_tool_json(
|
||||
id,
|
||||
tool_name,
|
||||
invalid_input_json,
|
||||
json_parse_error,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
// These are all errors we can't automatically attempt to recover from (e.g. by retrying)
|
||||
err @ LanguageModelCompletionError::BadRequestFormat |
|
||||
err @ LanguageModelCompletionError::AuthenticationError |
|
||||
err @ LanguageModelCompletionError::PermissionError |
|
||||
err @ LanguageModelCompletionError::ApiEndpointNotFound |
|
||||
err @ LanguageModelCompletionError::SerializeRequest(_) |
|
||||
err @ LanguageModelCompletionError::BuildRequestBody(_) |
|
||||
err @ LanguageModelCompletionError::HttpSend(_) => {
|
||||
anyhow::bail!(err);
|
||||
}
|
||||
LanguageModelCompletionError::Other(error) => {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match event {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
request_assistant_message_id =
|
||||
Some(thread.insert_assistant_message(
|
||||
@@ -1683,9 +1607,7 @@ impl Thread {
|
||||
};
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::RedactedThinking {
|
||||
data
|
||||
} => {
|
||||
LanguageModelCompletionEvent::RedactedThinking { data } => {
|
||||
thread.received_chunk();
|
||||
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
@@ -1734,6 +1656,21 @@ impl Thread {
|
||||
});
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input: invalid_input_json,
|
||||
json_parse_error,
|
||||
} => {
|
||||
thread.receive_invalid_tool_json(
|
||||
id,
|
||||
tool_name,
|
||||
invalid_input_json,
|
||||
json_parse_error,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
||||
if let Some(completion) = thread
|
||||
.pending_completions
|
||||
@@ -1741,23 +1678,34 @@ impl Thread {
|
||||
.find(|completion| completion.id == pending_completion_id)
|
||||
{
|
||||
match status_update {
|
||||
CompletionRequestStatus::Queued {
|
||||
position,
|
||||
} => {
|
||||
completion.queue_state = QueueState::Queued { position };
|
||||
CompletionRequestStatus::Queued { position } => {
|
||||
completion.queue_state =
|
||||
QueueState::Queued { position };
|
||||
}
|
||||
CompletionRequestStatus::Started => {
|
||||
completion.queue_state = QueueState::Started;
|
||||
completion.queue_state = QueueState::Started;
|
||||
}
|
||||
CompletionRequestStatus::Failed {
|
||||
code, message, request_id
|
||||
code,
|
||||
message,
|
||||
request_id: _,
|
||||
retry_after,
|
||||
} => {
|
||||
anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
|
||||
return Err(
|
||||
LanguageModelCompletionError::from_cloud_failure(
|
||||
model.upstream_provider_name(),
|
||||
code,
|
||||
message,
|
||||
retry_after.map(Duration::from_secs_f64),
|
||||
),
|
||||
);
|
||||
}
|
||||
CompletionRequestStatus::UsageUpdated {
|
||||
amount, limit
|
||||
} => {
|
||||
thread.update_model_request_usage(amount as u32, limit, cx);
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit } => {
|
||||
thread.update_model_request_usage(
|
||||
amount as u32,
|
||||
limit,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
CompletionRequestStatus::ToolUseLimitReached => {
|
||||
thread.tool_use_limit_reached = true;
|
||||
@@ -1808,10 +1756,11 @@ impl Thread {
|
||||
Ok(stop_reason) => {
|
||||
match stop_reason {
|
||||
StopReason::ToolUse => {
|
||||
let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
|
||||
let tool_uses =
|
||||
thread.use_pending_tools(window, model.clone(), cx);
|
||||
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
||||
}
|
||||
StopReason::EndTurn | StopReason::MaxTokens => {
|
||||
StopReason::EndTurn | StopReason::MaxTokens => {
|
||||
thread.project.update(cx, |project, cx| {
|
||||
project.set_agent_location(None, cx);
|
||||
});
|
||||
@@ -1827,7 +1776,9 @@ impl Thread {
|
||||
{
|
||||
let mut messages_to_remove = Vec::new();
|
||||
|
||||
for (ix, message) in thread.messages.iter().enumerate().rev() {
|
||||
for (ix, message) in
|
||||
thread.messages.iter().enumerate().rev()
|
||||
{
|
||||
messages_to_remove.push(message.id);
|
||||
|
||||
if message.role == Role::User {
|
||||
@@ -1835,7 +1786,9 @@ impl Thread {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(prev_message) = thread.messages.get(ix - 1) {
|
||||
if let Some(prev_message) =
|
||||
thread.messages.get(ix - 1)
|
||||
{
|
||||
if prev_message.role == Role::Assistant {
|
||||
break;
|
||||
}
|
||||
@@ -1850,14 +1803,16 @@ impl Thread {
|
||||
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
|
||||
header: "Language model refusal".into(),
|
||||
message: "Model refused to generate content for safety reasons.".into(),
|
||||
message:
|
||||
"Model refused to generate content for safety reasons."
|
||||
.into(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// We successfully completed, so cancel any remaining retries.
|
||||
thread.retry_state = None;
|
||||
},
|
||||
}
|
||||
Err(error) => {
|
||||
thread.project.update(cx, |project, cx| {
|
||||
project.set_agent_location(None, cx);
|
||||
@@ -1883,26 +1838,38 @@ impl Thread {
|
||||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::ModelRequestLimitReached { plan: error.plan },
|
||||
));
|
||||
} else if let Some(known_error) =
|
||||
error.downcast_ref::<LanguageModelKnownError>()
|
||||
} else if let Some(completion_error) =
|
||||
error.downcast_ref::<LanguageModelCompletionError>()
|
||||
{
|
||||
match known_error {
|
||||
LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
|
||||
use LanguageModelCompletionError::*;
|
||||
match &completion_error {
|
||||
PromptTooLarge { tokens, .. } => {
|
||||
let tokens = tokens.unwrap_or_else(|| {
|
||||
// We didn't get an exact token count from the API, so fall back on our estimate.
|
||||
thread
|
||||
.total_token_usage()
|
||||
.map(|usage| usage.total)
|
||||
.unwrap_or(0)
|
||||
// We know the context window was exceeded in practice, so if our estimate was
|
||||
// lower than max tokens, the estimate was wrong; return that we exceeded by 1.
|
||||
.max(model.max_token_count().saturating_add(1))
|
||||
});
|
||||
thread.exceeded_window_error = Some(ExceededWindowError {
|
||||
model_id: model.id(),
|
||||
token_count: *tokens,
|
||||
token_count: tokens,
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
LanguageModelKnownError::RateLimitExceeded { retry_after } => {
|
||||
let provider_name = model.provider_name();
|
||||
let error_message = format!(
|
||||
"{}'s API rate limit exceeded",
|
||||
provider_name.0.as_ref()
|
||||
);
|
||||
|
||||
RateLimitExceeded {
|
||||
retry_after: Some(retry_after),
|
||||
..
|
||||
}
|
||||
| ServerOverloaded {
|
||||
retry_after: Some(retry_after),
|
||||
..
|
||||
} => {
|
||||
thread.handle_rate_limit_error(
|
||||
&error_message,
|
||||
&completion_error,
|
||||
*retry_after,
|
||||
model.clone(),
|
||||
intent,
|
||||
@@ -1911,15 +1878,9 @@ impl Thread {
|
||||
);
|
||||
retry_scheduled = true;
|
||||
}
|
||||
LanguageModelKnownError::Overloaded => {
|
||||
let provider_name = model.provider_name();
|
||||
let error_message = format!(
|
||||
"{}'s API servers are overloaded right now",
|
||||
provider_name.0.as_ref()
|
||||
);
|
||||
|
||||
RateLimitExceeded { .. } | ServerOverloaded { .. } => {
|
||||
retry_scheduled = thread.handle_retryable_error(
|
||||
&error_message,
|
||||
&completion_error,
|
||||
model.clone(),
|
||||
intent,
|
||||
window,
|
||||
@@ -1929,15 +1890,11 @@ impl Thread {
|
||||
emit_generic_error(error, cx);
|
||||
}
|
||||
}
|
||||
LanguageModelKnownError::ApiInternalServerError => {
|
||||
let provider_name = model.provider_name();
|
||||
let error_message = format!(
|
||||
"{}'s API server reported an internal server error",
|
||||
provider_name.0.as_ref()
|
||||
);
|
||||
|
||||
ApiInternalServerError { .. }
|
||||
| ApiReadResponseError { .. }
|
||||
| HttpSend { .. } => {
|
||||
retry_scheduled = thread.handle_retryable_error(
|
||||
&error_message,
|
||||
&completion_error,
|
||||
model.clone(),
|
||||
intent,
|
||||
window,
|
||||
@@ -1947,12 +1904,16 @@ impl Thread {
|
||||
emit_generic_error(error, cx);
|
||||
}
|
||||
}
|
||||
LanguageModelKnownError::ReadResponseError(_) |
|
||||
LanguageModelKnownError::DeserializeResponse(_) |
|
||||
LanguageModelKnownError::UnknownResponseFormat(_) => {
|
||||
// In the future we will attempt to re-roll response, but only once
|
||||
emit_generic_error(error, cx);
|
||||
}
|
||||
NoApiKey { .. }
|
||||
| HttpResponseError { .. }
|
||||
| BadRequestFormat { .. }
|
||||
| AuthenticationError { .. }
|
||||
| PermissionError { .. }
|
||||
| ApiEndpointNotFound { .. }
|
||||
| SerializeRequest { .. }
|
||||
| BuildRequestBody { .. }
|
||||
| DeserializeResponse { .. }
|
||||
| Other { .. } => emit_generic_error(error, cx),
|
||||
}
|
||||
} else {
|
||||
emit_generic_error(error, cx);
|
||||
@@ -2084,7 +2045,7 @@ impl Thread {
|
||||
|
||||
fn handle_rate_limit_error(
|
||||
&mut self,
|
||||
error_message: &str,
|
||||
error: &LanguageModelCompletionError,
|
||||
retry_after: Duration,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
@@ -2092,9 +2053,10 @@ impl Thread {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// For rate limit errors, we only retry once with the specified duration
|
||||
let retry_message = format!(
|
||||
"{error_message}. Retrying in {} seconds…",
|
||||
retry_after.as_secs()
|
||||
let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
|
||||
log::warn!(
|
||||
"Retrying completion request in {} seconds: {error:?}",
|
||||
retry_after.as_secs(),
|
||||
);
|
||||
|
||||
// Add a UI-only message instead of a regular message
|
||||
@@ -2127,18 +2089,18 @@ impl Thread {
|
||||
|
||||
fn handle_retryable_error(
|
||||
&mut self,
|
||||
error_message: &str,
|
||||
error: &LanguageModelCompletionError,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
|
||||
self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
|
||||
}
|
||||
|
||||
fn handle_retryable_error_with_delay(
|
||||
&mut self,
|
||||
error_message: &str,
|
||||
error: &LanguageModelCompletionError,
|
||||
custom_delay: Option<Duration>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
intent: CompletionIntent,
|
||||
@@ -2168,8 +2130,12 @@ impl Thread {
|
||||
// Add a transient message to inform the user
|
||||
let delay_secs = delay.as_secs();
|
||||
let retry_message = format!(
|
||||
"{}. Retrying (attempt {} of {}) in {} seconds...",
|
||||
error_message, attempt, max_attempts, delay_secs
|
||||
"{error}. Retrying (attempt {attempt} of {max_attempts}) \
|
||||
in {delay_secs} seconds..."
|
||||
);
|
||||
log::warn!(
|
||||
"Retrying completion request (attempt {attempt} of {max_attempts}) \
|
||||
in {delay_secs} seconds: {error:?}",
|
||||
);
|
||||
|
||||
// Add a UI-only message instead of a regular message
|
||||
@@ -4108,10 +4074,6 @@ fn main() {{
|
||||
self.inner.supports_images()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
self.inner.max_image_size()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
self.inner.telemetry_id()
|
||||
}
|
||||
@@ -4143,9 +4105,15 @@ fn main() {{
|
||||
>,
|
||||
> {
|
||||
let error = match self.error_type {
|
||||
TestError::Overloaded => LanguageModelCompletionError::Overloaded,
|
||||
TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: self.provider_name(),
|
||||
retry_after: None,
|
||||
},
|
||||
TestError::InternalServerError => {
|
||||
LanguageModelCompletionError::ApiInternalServerError
|
||||
LanguageModelCompletionError::ApiInternalServerError {
|
||||
provider: self.provider_name(),
|
||||
message: "I'm a teapot orbiting the sun".to_string(),
|
||||
}
|
||||
}
|
||||
};
|
||||
async move {
|
||||
@@ -4621,10 +4589,6 @@ fn main() {{
|
||||
self.inner.supports_images()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
self.inner.max_image_size()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
self.inner.telemetry_id()
|
||||
}
|
||||
@@ -4657,9 +4621,13 @@ fn main() {{
|
||||
> {
|
||||
if !*self.failed_once.lock() {
|
||||
*self.failed_once.lock() = true;
|
||||
let provider = self.provider_name();
|
||||
// Return error on first attempt
|
||||
let stream = futures::stream::once(async move {
|
||||
Err(LanguageModelCompletionError::Overloaded)
|
||||
Err(LanguageModelCompletionError::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
})
|
||||
});
|
||||
async move { Ok(stream.boxed()) }.boxed()
|
||||
} else {
|
||||
@@ -4790,10 +4758,6 @@ fn main() {{
|
||||
self.inner.supports_images()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
self.inner.max_image_size()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
self.inner.telemetry_id()
|
||||
}
|
||||
@@ -4826,9 +4790,13 @@ fn main() {{
|
||||
> {
|
||||
if !*self.failed_once.lock() {
|
||||
*self.failed_once.lock() = true;
|
||||
let provider = self.provider_name();
|
||||
// Return error on first attempt
|
||||
let stream = futures::stream::once(async move {
|
||||
Err(LanguageModelCompletionError::Overloaded)
|
||||
Err(LanguageModelCompletionError::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
})
|
||||
});
|
||||
async move { Ok(stream.boxed()) }.boxed()
|
||||
} else {
|
||||
@@ -4951,10 +4919,6 @@ fn main() {{
|
||||
self.inner.supports_images()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
self.inner.max_image_size()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
self.inner.telemetry_id()
|
||||
}
|
||||
@@ -4985,10 +4949,12 @@ fn main() {{
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let provider = self.provider_name();
|
||||
async move {
|
||||
let stream = futures::stream::once(async move {
|
||||
Err(LanguageModelCompletionError::RateLimitExceeded {
|
||||
retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
|
||||
provider,
|
||||
retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
|
||||
})
|
||||
});
|
||||
Ok(stream.boxed())
|
||||
@@ -5398,192 +5364,4 @@ fn main() {{
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_image_size_limit_in_thread(cx: &mut TestAppContext) {
|
||||
use gpui::DevicePixels;
|
||||
use language_model::{
|
||||
LanguageModelImage,
|
||||
fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
|
||||
};
|
||||
|
||||
init_test_settings(cx);
|
||||
let project = create_test_project(cx, serde_json::json!({})).await;
|
||||
let (_, _, thread, _, _) = setup_test_environment(cx, project).await;
|
||||
|
||||
// Create a small image that's under the limit
|
||||
let small_image = LanguageModelImage {
|
||||
source: "small_data".into(),
|
||||
size: gpui::size(DevicePixels(10), DevicePixels(10)),
|
||||
};
|
||||
|
||||
// Create a large image that exceeds typical limits (10MB)
|
||||
let large_image_source = "x".repeat(10_485_760); // 10MB
|
||||
let large_image = LanguageModelImage {
|
||||
source: large_image_source.into(),
|
||||
size: gpui::size(DevicePixels(1024), DevicePixels(1024)),
|
||||
};
|
||||
|
||||
// Create a loaded context with both images
|
||||
let loaded_context = ContextLoadResult {
|
||||
loaded_context: LoadedContext {
|
||||
contexts: vec![],
|
||||
text: "Test message".to_string(),
|
||||
images: vec![small_image.clone(), large_image.clone()],
|
||||
},
|
||||
referenced_buffers: HashSet::default(),
|
||||
};
|
||||
|
||||
// Insert a user message with the loaded context
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Test with images", loaded_context, None, vec![], cx);
|
||||
});
|
||||
|
||||
// Create a model with 500KB image size limit
|
||||
let _provider = Arc::new(FakeLanguageModelProvider);
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
// Note: FakeLanguageModel doesn't support images by default (max_image_size returns 0)
|
||||
// so we'll test that images are excluded when the model doesn't support them
|
||||
|
||||
// Generate the completion request
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
|
||||
});
|
||||
|
||||
// Verify that no images were included (because FakeLanguageModel doesn't support images)
|
||||
let mut image_count = 0;
|
||||
let mut has_text = false;
|
||||
for message in &request.messages {
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
if text.contains("Test message") {
|
||||
has_text = true;
|
||||
}
|
||||
}
|
||||
MessageContent::Image(_) => {
|
||||
image_count += 1;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(has_text, "Text content should be included");
|
||||
assert_eq!(
|
||||
image_count, 0,
|
||||
"No images should be included when model doesn't support them"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_image_size_limit_with_anthropic_model(_cx: &mut TestAppContext) {
|
||||
use gpui::{DevicePixels, SharedString};
|
||||
use language_model::{
|
||||
LanguageModelId, LanguageModelImage, LanguageModelName, LanguageModelProviderId,
|
||||
LanguageModelProviderName,
|
||||
};
|
||||
|
||||
// Test with a model that has specific size limits (like Anthropic's 5MB limit)
|
||||
// We'll create a simple test to verify the logic works correctly
|
||||
|
||||
// Create test images
|
||||
let small_image = LanguageModelImage {
|
||||
source: "small".into(),
|
||||
size: gpui::size(DevicePixels(100), DevicePixels(100)),
|
||||
};
|
||||
|
||||
let large_image_source = "x".repeat(6_000_000); // 6MB - over Anthropic's 5MB limit
|
||||
let large_image = LanguageModelImage {
|
||||
source: large_image_source.into(),
|
||||
size: gpui::size(DevicePixels(2000), DevicePixels(2000)),
|
||||
};
|
||||
|
||||
let loaded_context = LoadedContext {
|
||||
contexts: vec![],
|
||||
text: "Test".to_string(),
|
||||
images: vec![small_image.clone(), large_image.clone()],
|
||||
};
|
||||
|
||||
// Test the add_to_request_message_with_model method directly
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
// Use the test from context.rs as a guide - create a mock model with 5MB limit
|
||||
struct TestModel5MB;
|
||||
impl language_model::LanguageModel for TestModel5MB {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(SharedString::from("test"))
|
||||
}
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(SharedString::from("Test 5MB"))
|
||||
}
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(SharedString::from("test"))
|
||||
}
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(SharedString::from("Test"))
|
||||
}
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn supports_tool_choice(&self, _: language_model::LanguageModelToolChoice) -> bool {
|
||||
false
|
||||
}
|
||||
fn max_image_size(&self) -> u64 {
|
||||
5_242_880 // 5MB like Anthropic
|
||||
}
|
||||
fn telemetry_id(&self) -> String {
|
||||
"test".to_string()
|
||||
}
|
||||
fn max_token_count(&self) -> u64 {
|
||||
100_000
|
||||
}
|
||||
fn count_tokens(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<u64>> {
|
||||
Box::pin(async { Ok(0) })
|
||||
}
|
||||
fn stream_completion(
|
||||
&self,
|
||||
_request: language_model::LanguageModelRequest,
|
||||
_cx: &gpui::AsyncApp,
|
||||
) -> futures::future::BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<
|
||||
language_model::LanguageModelCompletionEvent,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
language_model::LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
Box::pin(async {
|
||||
Err(language_model::LanguageModelCompletionError::Other(
|
||||
anyhow::anyhow!("Not implemented"),
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let model: Arc<dyn language_model::LanguageModel> = Arc::new(TestModel5MB);
|
||||
loaded_context.add_to_request_message_with_model(&mut request_message, &model);
|
||||
|
||||
// Should have text and only the small image
|
||||
let mut image_count = 0;
|
||||
for content in &request_message.content {
|
||||
if matches!(content, MessageContent::Image(_)) {
|
||||
image_count += 1;
|
||||
}
|
||||
}
|
||||
assert_eq!(image_count, 1, "Only the small image should be included");
|
||||
}
|
||||
}
|
||||
|
||||
49
crates/agent2/Cargo.toml
Normal file
49
crates/agent2/Cargo.toml
Normal file
@@ -0,0 +1,49 @@
|
||||
[package]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/agent2.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
worktree.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
worktree = { workspace = true, "features" = ["test-support"] }
|
||||
1
crates/agent2/LICENSE-GPL
Symbolic link
1
crates/agent2/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
6
crates/agent2/src/agent2.rs
Normal file
6
crates/agent2/src/agent2.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
mod prompts;
|
||||
mod templates;
|
||||
mod thread;
|
||||
mod tools;
|
||||
|
||||
pub use thread::*;
|
||||
29
crates/agent2/src/prompts.rs
Normal file
29
crates/agent2/src/prompts.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::{
|
||||
templates::{BaseTemplate, Template, Templates, WorktreeData},
|
||||
thread::Prompt,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity};
|
||||
use project::Project;
|
||||
|
||||
struct BasePrompt {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl Prompt for BasePrompt {
|
||||
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
|
||||
BaseTemplate {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
shell: util::get_system_shell(),
|
||||
worktrees: self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| WorktreeData {
|
||||
root_name: worktree.read(cx).root_name().to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
.render(templates)
|
||||
}
|
||||
}
|
||||
57
crates/agent2/src/templates.rs
Normal file
57
crates/agent2/src/templates.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct BaseTemplate {
|
||||
pub os: String,
|
||||
pub shell: String,
|
||||
pub worktrees: Vec<WorktreeData>,
|
||||
}
|
||||
|
||||
impl Template for BaseTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "base.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeData {
|
||||
pub root_name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GlobTemplate {
|
||||
pub project_roots: String,
|
||||
}
|
||||
|
||||
impl Template for GlobTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
||||
}
|
||||
56
crates/agent2/src/templates/base.hbs
Normal file
56
crates/agent2/src/templates/base.hbs
Normal file
@@ -0,0 +1,56 @@
|
||||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
8
crates/agent2/src/templates/glob.hbs
Normal file
8
crates/agent2/src/templates/glob.hbs
Normal file
@@ -0,0 +1,8 @@
|
||||
Find paths on disk with glob patterns.
|
||||
|
||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
||||
|
||||
{{project_roots}}
|
||||
|
||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
||||
sure to anchor them with one of the directories listed above.
|
||||
420
crates/agent2/src/thread.rs
Normal file
420
crates/agent2/src/thread.rs
Normal file
@@ -0,0 +1,420 @@
|
||||
use crate::templates::Templates;
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{channel::mpsc, future};
|
||||
use gpui::{App, Context, SharedString, Task};
|
||||
use language_model::{
|
||||
CompletionIntent, CompletionMode, LanguageModel, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
||||
|
||||
pub trait Prompt {
|
||||
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AgentToolErased>>,
|
||||
templates: Arc<Templates>,
|
||||
// project: Entity<Project>,
|
||||
// action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(templates: Arc<Templates>) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
system_prompts: Vec::new(),
|
||||
running_turn: None,
|
||||
tools: BTreeMap::default(),
|
||||
templates,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[AgentMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
self.tools.insert(tool.name(), tool.erase());
|
||||
}
|
||||
|
||||
pub fn remove_tool(&mut self, name: &str) -> bool {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let system_message = self.build_system_message(cx);
|
||||
self.messages.extend(system_message);
|
||||
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content.into()],
|
||||
});
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// println!(
|
||||
// "request: {}",
|
||||
// serde_json::to_string_pretty(&request).unwrap()
|
||||
// );
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
let mut tool_uses = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event,
|
||||
events_tx.clone(),
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// If there are tool uses, wait for their results to be
|
||||
// computed, then send them together in a single message on
|
||||
// the next loop iteration.
|
||||
let tool_results = future::join_all(tool_uses).await;
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: tool_results.into_iter().map(Into::into).collect(),
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
}
|
||||
}));
|
||||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
||||
let mut system_message = AgentMessage {
|
||||
role: Role::System,
|
||||
content: Vec::new(),
|
||||
};
|
||||
|
||||
for prompt in &self.system_prompts {
|
||||
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
||||
system_message
|
||||
.content
|
||||
.push(MessageContent::Text(rendered_prompt));
|
||||
}
|
||||
}
|
||||
|
||||
(!system_message.content.is_empty()).then_some(system_message)
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
use LanguageModelCompletionEvent::*;
|
||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||
|
||||
match event {
|
||||
Text(new_text) => self.handle_text_event(new_text, cx),
|
||||
Thinking { text, signature } => {
|
||||
todo!()
|
||||
}
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, cx);
|
||||
}
|
||||
StartMessage { role, .. } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(_) => {}
|
||||
Stop(stop_reason) => self.handle_stop_event(stop_reason),
|
||||
StatusUpdate(_completion_request_status) => {}
|
||||
RedactedThinking { data } => todo!(),
|
||||
ToolUseJsonParseError {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
} => todo!(),
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_stop_event(&mut self, stop_reason: StopReason) {
|
||||
match stop_reason {
|
||||
StopReason::EndTurn | StopReason::ToolUse => {}
|
||||
StopReason::MaxTokens => todo!(),
|
||||
StopReason::Refusal => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
*last_tool_use = tool_use.clone();
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if push_new_tool_use {
|
||||
last_message.content.push(tool_use.clone().into());
|
||||
}
|
||||
|
||||
if !tool_use.is_input_complete {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(format!(
|
||||
"No tool named {} exists",
|
||||
tool_use.name
|
||||
))),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
||||
if self
|
||||
.messages
|
||||
.last()
|
||||
.map_or(true, |m| m.role != Role::Assistant)
|
||||
{
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
) -> LanguageModelRequest {
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
messages: self.build_request_messages(),
|
||||
tools: self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
self.messages
|
||||
.iter()
|
||||
.map(|message| LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
assistant_tools::root_schema_for::<Self::Input>(format)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AgentToolErased> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub trait AgentToolErased {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AgentToolErased for Erased<Arc<T>>
|
||||
where
|
||||
T: AgentTool,
|
||||
{
|
||||
fn name(&self) -> SharedString {
|
||||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
254
crates/agent2/src/thread/tests.rs
Normal file
254
crates/agent2/src/thread/tests.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
use super::*;
|
||||
use client::{proto::language_server_prompt_request, Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, Entity, TestAppContext};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, MessageContent, StopReason,
|
||||
};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
agent.update(cx, |agent, _cx| {
|
||||
assert_eq!(
|
||||
agent.messages.last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.add_tool(EchoTool);
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
|
||||
// Test a tool calls that's likely to complete *after* streaming stops.
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.remove_tool(&AgentTool::name(&EchoTool));
|
||||
agent.add_tool(DelayTool);
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
agent.update(cx, |agent, _cx| {
|
||||
assert!(agent
|
||||
.messages
|
||||
.last()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = agent.update(cx, |agent, cx| {
|
||||
agent.add_tool(WordListTool);
|
||||
agent.send(model.clone(), "Test the word_list tool.", cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
|
||||
agent.update(cx, |agent, _cx| {
|
||||
// Look for a tool use in the agent's last message
|
||||
let last_content = agent.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_use_event.is_input_complete {
|
||||
last_tool_use
|
||||
.input
|
||||
.get("a")
|
||||
.expect("'a' has streamed because input is now complete");
|
||||
last_tool_use
|
||||
.input
|
||||
.get("g")
|
||||
.expect("'g' has streamed because input is now complete");
|
||||
} else {
|
||||
if !last_tool_use.is_input_complete
|
||||
&& last_tool_use.input.get("g").is_none()
|
||||
{
|
||||
saw_partial_tool_use = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("last content should be a tool use");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_partial_tool_use,
|
||||
"should see at least one partially streamed tool use in the history"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let AgentTest { model, agent, .. } = setup(cx).await;
|
||||
|
||||
// Test concurrent tool calls with different delay times
|
||||
let events = agent
|
||||
.update(cx, |agent, cx| {
|
||||
agent.add_tool(DelayTool);
|
||||
agent.send(
|
||||
model.clone(),
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let stop_reasons = stop_events(events);
|
||||
if stop_reasons.len() == 2 {
|
||||
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
|
||||
} else if stop_reasons.len() == 3 {
|
||||
assert_eq!(
|
||||
stop_reasons,
|
||||
vec![
|
||||
StopReason::ToolUse,
|
||||
StopReason::ToolUse,
|
||||
StopReason::EndTurn
|
||||
]
|
||||
);
|
||||
} else {
|
||||
panic!("Expected either 1 or 2 tool uses followed by end turn");
|
||||
}
|
||||
|
||||
agent.update(cx, |agent, _cx| {
|
||||
let last_message = agent.messages.last().unwrap();
|
||||
let text = last_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
assert!(text.contains("Ding"));
|
||||
});
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> Vec<StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct AgentTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
agent: Entity<Thread>,
|
||||
}
|
||||
|
||||
async fn setup(cx: &mut TestAppContext) -> AgentTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let fs = FakeFs::new(cx.executor().clone());
|
||||
// let project = Project::test(fs.clone(), [], cx).await;
|
||||
// let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let templates = Templates::new();
|
||||
let agent = cx.new(|_| Thread::new(templates));
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
AgentTest { model, agent }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
83
crates/agent2/src/thread/tests/test_tools.rs
Normal file
83
crates/agent2/src/thread/tests/test_tools.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use super::*;
|
||||
|
||||
/// A tool that echoes its input
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
/// The text to echo.
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok(input.text))
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that waits for a specified delay
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct DelayToolInput {
|
||||
/// The delay in milliseconds.
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
pub struct DelayTool;
|
||||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
cx.foreground_executor().spawn(async move {
|
||||
smol::Timer::after(Duration::from_millis(input.ms)).await;
|
||||
Ok("Ding".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that takes an object with map from letters to random words starting with that letter.
|
||||
/// All fiealds are required! Pass a word for every letter!
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct WordListInput {
|
||||
/// Provide a random word that starts with A.
|
||||
a: Option<String>,
|
||||
/// Provide a random word that starts with B.
|
||||
b: Option<String>,
|
||||
/// Provide a random word that starts with C.
|
||||
c: Option<String>,
|
||||
/// Provide a random word that starts with D.
|
||||
d: Option<String>,
|
||||
/// Provide a random word that starts with E.
|
||||
e: Option<String>,
|
||||
/// Provide a random word that starts with F.
|
||||
f: Option<String>,
|
||||
/// Provide a random word that starts with G.
|
||||
g: Option<String>,
|
||||
}
|
||||
|
||||
pub struct WordListTool;
|
||||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok("ok".to_string()))
|
||||
}
|
||||
}
|
||||
1
crates/agent2/src/tools.rs
Normal file
1
crates/agent2/src/tools.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod glob;
|
||||
76
crates/agent2/src/tools/glob.rs
Normal file
76
crates/agent2/src/tools/glob.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot as WorktreeSnapshot;
|
||||
|
||||
use crate::{
|
||||
templates::{GlobTemplate, Template, Templates},
|
||||
thread::AgentTool,
|
||||
};
|
||||
|
||||
// Description is dynamic, see `fn description` below
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct GlobInput {
|
||||
/// A POSIX glob pattern
|
||||
glob: SharedString,
|
||||
}
|
||||
|
||||
struct GlobTool {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl AgentTool for GlobTool {
|
||||
type Input = GlobInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"glob".into()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
let project_roots = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).root_name().into())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
GlobTemplate { project_roots }
|
||||
.render(&self.templates)
|
||||
.expect("template failed to render")
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
||||
};
|
||||
|
||||
let snapshots: Vec<WorktreeSnapshot> = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
});
|
||||
let output = paths
|
||||
.map(|path| format!("{}\n", path.display()))
|
||||
.collect::<String>();
|
||||
Ok(output)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,10 @@ use anyhow::{Result, bail};
|
||||
use collections::IndexMap;
|
||||
use gpui::{App, Pixels, SharedString};
|
||||
use language_model::LanguageModel;
|
||||
use schemars::{JsonSchema, schema::Schema};
|
||||
use schemars::{JsonSchema, json_schema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use std::borrow::Cow;
|
||||
|
||||
pub use crate::agent_profile::*;
|
||||
|
||||
@@ -49,7 +50,7 @@ pub struct AgentSettings {
|
||||
pub dock: AgentDockPosition,
|
||||
pub default_width: Pixels,
|
||||
pub default_height: Pixels,
|
||||
pub default_model: LanguageModelSelection,
|
||||
pub default_model: Option<LanguageModelSelection>,
|
||||
pub inline_assistant_model: Option<LanguageModelSelection>,
|
||||
pub commit_message_model: Option<LanguageModelSelection>,
|
||||
pub thread_summary_model: Option<LanguageModelSelection>,
|
||||
@@ -211,7 +212,6 @@ impl AgentSettingsContent {
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct AgentSettingsContent {
|
||||
/// Whether the Agent is enabled.
|
||||
///
|
||||
@@ -321,29 +321,27 @@ pub struct LanguageModelSelection {
|
||||
pub struct LanguageModelProviderSetting(pub String);
|
||||
|
||||
impl JsonSchema for LanguageModelProviderSetting {
|
||||
fn schema_name() -> String {
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"LanguageModelProviderSetting".into()
|
||||
}
|
||||
|
||||
fn json_schema(_: &mut schemars::r#gen::SchemaGenerator) -> Schema {
|
||||
schemars::schema::SchemaObject {
|
||||
enum_values: Some(vec![
|
||||
"anthropic".into(),
|
||||
"amazon-bedrock".into(),
|
||||
"google".into(),
|
||||
"lmstudio".into(),
|
||||
"ollama".into(),
|
||||
"openai".into(),
|
||||
"zed.dev".into(),
|
||||
"copilot_chat".into(),
|
||||
"deepseek".into(),
|
||||
"openrouter".into(),
|
||||
"mistral".into(),
|
||||
"vercel".into(),
|
||||
]),
|
||||
..Default::default()
|
||||
}
|
||||
.into()
|
||||
fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"enum": [
|
||||
"anthropic",
|
||||
"amazon-bedrock",
|
||||
"google",
|
||||
"lmstudio",
|
||||
"ollama",
|
||||
"openai",
|
||||
"zed.dev",
|
||||
"copilot_chat",
|
||||
"deepseek",
|
||||
"openrouter",
|
||||
"mistral",
|
||||
"vercel"
|
||||
]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,15 +357,6 @@ impl From<&str> for LanguageModelProviderSetting {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LanguageModelSelection {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider: LanguageModelProviderSetting("openai".to_string()),
|
||||
model: "gpt-4".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AgentProfileContent {
|
||||
pub name: Arc<str>,
|
||||
@@ -411,7 +400,10 @@ impl Settings for AgentSettings {
|
||||
&mut settings.default_height,
|
||||
value.default_height.map(Into::into),
|
||||
);
|
||||
merge(&mut settings.default_model, value.default_model.clone());
|
||||
settings.default_model = value
|
||||
.default_model
|
||||
.clone()
|
||||
.or(settings.default_model.take());
|
||||
settings.inline_assistant_model = value
|
||||
.inline_assistant_model
|
||||
.clone()
|
||||
|
||||
@@ -889,46 +889,6 @@ impl ActiveThread {
|
||||
&self.text_thread_store
|
||||
}
|
||||
|
||||
pub fn validate_image(&self, image: &Arc<gpui::Image>, cx: &App) -> Result<(), String> {
|
||||
let image_size = image.bytes().len() as u64;
|
||||
|
||||
if let Some(model) = self.thread.read(cx).configured_model() {
|
||||
let max_size = model.model.max_image_size();
|
||||
|
||||
if image_size > max_size {
|
||||
if max_size == 0 {
|
||||
Err(format!(
|
||||
"{} does not support image attachments",
|
||||
model.model.name().0
|
||||
))
|
||||
} else {
|
||||
let size_mb = image_size as f64 / 1_048_576.0;
|
||||
let max_size_mb = max_size as f64 / 1_048_576.0;
|
||||
Err(format!(
|
||||
"Image ({:.1} MB) exceeds {}'s {:.1} MB size limit",
|
||||
size_mb,
|
||||
model.model.name().0,
|
||||
max_size_mb
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
} else {
|
||||
// No model configured, use default 10MB limit
|
||||
const DEFAULT_MAX_SIZE: u64 = 10 * 1024 * 1024;
|
||||
if image_size > DEFAULT_MAX_SIZE {
|
||||
let size_mb = image_size as f64 / 1_048_576.0;
|
||||
Err(format!(
|
||||
"Image ({:.1} MB) exceeds the 10 MB size limit",
|
||||
size_mb
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn push_rendered_message(&mut self, id: MessageId, rendered_message: RenderedMessage) {
|
||||
let old_len = self.messages.len();
|
||||
self.messages.push(id);
|
||||
@@ -1562,7 +1522,7 @@ impl ActiveThread {
|
||||
}
|
||||
|
||||
fn paste(&mut self, _: &Paste, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
attach_pasted_images_as_context_with_validation(&self.context_store, Some(self), cx);
|
||||
attach_pasted_images_as_context(&self.context_store, cx);
|
||||
}
|
||||
|
||||
fn cancel_editing_message(
|
||||
@@ -3743,14 +3703,6 @@ pub(crate) fn open_context(
|
||||
pub(crate) fn attach_pasted_images_as_context(
|
||||
context_store: &Entity<ContextStore>,
|
||||
cx: &mut App,
|
||||
) -> bool {
|
||||
attach_pasted_images_as_context_with_validation(context_store, None, cx)
|
||||
}
|
||||
|
||||
pub(crate) fn attach_pasted_images_as_context_with_validation(
|
||||
context_store: &Entity<ContextStore>,
|
||||
active_thread: Option<&ActiveThread>,
|
||||
cx: &mut App,
|
||||
) -> bool {
|
||||
let images = cx
|
||||
.read_from_clipboard()
|
||||
@@ -3772,67 +3724,9 @@ pub(crate) fn attach_pasted_images_as_context_with_validation(
|
||||
}
|
||||
cx.stop_propagation();
|
||||
|
||||
// Try to find the workspace for showing toasts
|
||||
let workspace = cx
|
||||
.active_window()
|
||||
.and_then(|window| window.downcast::<Workspace>());
|
||||
|
||||
context_store.update(cx, |store, cx| {
|
||||
for image in images {
|
||||
let image_arc = Arc::new(image);
|
||||
|
||||
// Validate image if we have an active thread
|
||||
let should_add = if let Some(thread) = active_thread {
|
||||
match thread.validate_image(&image_arc, cx) {
|
||||
Ok(()) => true,
|
||||
Err(err) => {
|
||||
// Show error toast if we have a workspace
|
||||
if let Some(workspace) = workspace {
|
||||
let _ = workspace.update(cx, |workspace, _, cx| {
|
||||
use workspace::{Toast, notifications::NotificationId};
|
||||
|
||||
struct ImageRejectionToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<ImageRejectionToast>(),
|
||||
err,
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No active thread, check against default limit
|
||||
let image_size = image_arc.bytes().len() as u64;
|
||||
const DEFAULT_MAX_SIZE: u64 = 10 * 1024 * 1024; // 10MB
|
||||
|
||||
if image_size > DEFAULT_MAX_SIZE {
|
||||
let size_mb = image_size as f64 / 1_048_576.0;
|
||||
let err = format!("Image ({:.1} MB) exceeds the 10 MB size limit", size_mb);
|
||||
|
||||
if let Some(workspace) = workspace {
|
||||
let _ = workspace.update(cx, |workspace, _, cx| {
|
||||
use workspace::{Toast, notifications::NotificationId};
|
||||
|
||||
struct ImageRejectionToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(NotificationId::unique::<ImageRejectionToast>(), err),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if should_add {
|
||||
store.add_image_instance(image_arc, cx);
|
||||
}
|
||||
store.add_image_instance(Arc::new(image), cx);
|
||||
}
|
||||
});
|
||||
true
|
||||
|
||||
@@ -16,7 +16,9 @@ use gpui::{
|
||||
Focusable, ScrollHandle, Subscription, Task, Transformation, WeakEntity, percentage,
|
||||
};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use language_model::{
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use notifications::status_toast::{StatusToast, ToastIcon};
|
||||
use project::{
|
||||
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
|
||||
@@ -86,6 +88,14 @@ impl AgentConfiguration {
|
||||
let scroll_handle = ScrollHandle::new();
|
||||
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
|
||||
|
||||
let mut expanded_provider_configurations = HashMap::default();
|
||||
if LanguageModelRegistry::read_global(cx)
|
||||
.provider(&ZED_CLOUD_PROVIDER_ID)
|
||||
.map_or(false, |cloud_provider| cloud_provider.must_accept_terms(cx))
|
||||
{
|
||||
expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true);
|
||||
}
|
||||
|
||||
let mut this = Self {
|
||||
fs,
|
||||
language_registry,
|
||||
@@ -94,7 +104,7 @@ impl AgentConfiguration {
|
||||
configuration_views_by_provider: HashMap::default(),
|
||||
context_server_store,
|
||||
expanded_context_server_tools: HashMap::default(),
|
||||
expanded_provider_configurations: HashMap::default(),
|
||||
expanded_provider_configurations,
|
||||
tools,
|
||||
_registry_subscription: registry_subscription,
|
||||
scroll_handle,
|
||||
|
||||
@@ -4,8 +4,6 @@ use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use gpui::{Image, ImageFormat};
|
||||
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -43,7 +41,7 @@ use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
|
||||
Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, FontWeight,
|
||||
Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, Hsla,
|
||||
KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, linear_color_stop,
|
||||
linear_gradient, prelude::*, pulsating_between,
|
||||
};
|
||||
@@ -61,7 +59,7 @@ use theme::ThemeSettings;
|
||||
use time::UtcOffset;
|
||||
use ui::utils::WithRemSize;
|
||||
use ui::{
|
||||
Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
|
||||
Banner, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
|
||||
PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
@@ -2027,9 +2025,7 @@ impl AgentPanel {
|
||||
.thread()
|
||||
.read(cx)
|
||||
.configured_model()
|
||||
.map_or(false, |model| {
|
||||
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
|
||||
});
|
||||
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
|
||||
|
||||
if !is_using_zed_provider {
|
||||
return false;
|
||||
@@ -2602,7 +2598,7 @@ impl AgentPanel {
|
||||
Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => {
|
||||
parent.child(Banner::new().severity(ui::Severity::Warning).child(
|
||||
h_flex().w_full().children(provider.render_accept_terms(
|
||||
LanguageModelProviderTosView::ThreadtEmptyState,
|
||||
LanguageModelProviderTosView::ThreadEmptyState,
|
||||
cx,
|
||||
)),
|
||||
))
|
||||
@@ -2693,58 +2689,90 @@ impl AgentPanel {
|
||||
Some(div().px_2().pb_2().child(banner).into_any_element())
|
||||
}
|
||||
|
||||
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
|
||||
let message = message.into();
|
||||
|
||||
IconButton::new("copy", IconName::Copy)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Copy Error Message"))
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
|
||||
})
|
||||
}
|
||||
|
||||
fn dismiss_error_button(
|
||||
&self,
|
||||
thread: &Entity<ActiveThread>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
IconButton::new("dismiss", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Dismiss Error"))
|
||||
.on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn upgrade_button(
|
||||
&self,
|
||||
thread: &Entity<ActiveThread>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
Button::new("upgrade", "Upgrade")
|
||||
.label_size(LabelSize::Small)
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
cx.notify();
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn error_callout_bg(&self, cx: &Context<Self>) -> Hsla {
|
||||
cx.theme().status().error.opacity(0.08)
|
||||
}
|
||||
|
||||
fn render_payment_required_error(
|
||||
&self,
|
||||
thread: &Entity<ActiveThread>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
|
||||
const ERROR_MESSAGE: &str =
|
||||
"You reached your free usage limit. Upgrade to Zed Pro for more prompts.";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(ERROR_MESSAGE))
|
||||
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
let icon = Icon::new(IconName::XCircle)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Error);
|
||||
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
cx.notify();
|
||||
}
|
||||
})))
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}))),
|
||||
div()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
Callout::new()
|
||||
.icon(icon)
|
||||
.title("Free Usage Exceeded")
|
||||
.description(ERROR_MESSAGE)
|
||||
.tertiary_action(self.upgrade_button(thread, cx))
|
||||
.secondary_action(self.create_copy_button(ERROR_MESSAGE))
|
||||
.primary_action(self.dismiss_error_button(thread, cx))
|
||||
.bg_color(self.error_callout_bg(cx)),
|
||||
)
|
||||
.into_any()
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_model_request_limit_reached_error(
|
||||
@@ -2754,67 +2782,28 @@ impl AgentPanel {
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let error_message = match plan {
|
||||
Plan::ZedPro => {
|
||||
"Model request limit reached. Upgrade to usage-based billing for more requests."
|
||||
}
|
||||
Plan::ZedProTrial => {
|
||||
"Model request limit reached. Upgrade to Zed Pro for more requests."
|
||||
}
|
||||
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
|
||||
};
|
||||
let call_to_action = match plan {
|
||||
Plan::ZedPro => "Upgrade to usage-based billing",
|
||||
Plan::ZedProTrial => "Upgrade to Zed Pro",
|
||||
Plan::Free => "Upgrade to Zed Pro",
|
||||
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
|
||||
Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.",
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Model Request Limit Reached").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(error_message)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(error_message))
|
||||
.child(
|
||||
Button::new("subscribe", call_to_action).on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
let icon = Icon::new(IconName::XCircle)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Error);
|
||||
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}))),
|
||||
div()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
Callout::new()
|
||||
.icon(icon)
|
||||
.title("Model Prompt Limit Reached")
|
||||
.description(error_message)
|
||||
.tertiary_action(self.upgrade_button(thread, cx))
|
||||
.secondary_action(self.create_copy_button(error_message))
|
||||
.primary_action(self.dismiss_error_button(thread, cx))
|
||||
.bg_color(self.error_callout_bg(cx)),
|
||||
)
|
||||
.into_any()
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_error_message(
|
||||
@@ -2825,40 +2814,24 @@ impl AgentPanel {
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let message_with_header = format!("{}\n{}", header, message);
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new(header).weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_32()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(message.clone())),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(message_with_header))
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener({
|
||||
let thread = thread.clone();
|
||||
move |_, _, _, cx| {
|
||||
thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}))),
|
||||
let icon = Icon::new(IconName::XCircle)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Error);
|
||||
|
||||
div()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
Callout::new()
|
||||
.icon(icon)
|
||||
.title(header)
|
||||
.description(message.clone())
|
||||
.primary_action(self.dismiss_error_button(thread, cx))
|
||||
.secondary_action(self.create_copy_button(message_with_header))
|
||||
.bg_color(self.error_callout_bg(cx)),
|
||||
)
|
||||
.into_any()
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_prompt_editor(
|
||||
@@ -2934,215 +2907,29 @@ impl AgentPanel {
|
||||
}),
|
||||
)
|
||||
.on_drop(cx.listener(move |this, paths: &ExternalPaths, window, cx| {
|
||||
eprintln!("=== ON_DROP EXTERNAL_PATHS HANDLER ===");
|
||||
eprintln!("Number of external paths: {}", paths.paths().len());
|
||||
for (i, path) in paths.paths().iter().enumerate() {
|
||||
eprintln!("External path {}: {:?}", i, path);
|
||||
}
|
||||
|
||||
match &this.active_view {
|
||||
ActiveView::Thread { thread, .. } => {
|
||||
eprintln!("In ActiveView::Thread branch");
|
||||
let thread = thread.clone();
|
||||
let paths = paths.paths();
|
||||
let workspace = this.workspace.clone();
|
||||
|
||||
for path in paths {
|
||||
eprintln!("Processing path: {:?}", path);
|
||||
// Check if it's an image file by extension
|
||||
let is_image = path.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.map(|ext| {
|
||||
matches!(
|
||||
ext.to_lowercase().as_str(),
|
||||
"jpg" | "jpeg" | "png" | "gif" | "webp" | "bmp" | "ico" | "svg" | "tiff" | "tif"
|
||||
)
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
eprintln!("Is image: {}", is_image);
|
||||
|
||||
if is_image {
|
||||
let path = path.to_path_buf();
|
||||
let thread = thread.clone();
|
||||
let workspace = workspace.clone();
|
||||
eprintln!("Spawning async task for image: {:?}", path);
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
eprintln!("=== INSIDE ASYNC IMAGE TASK ===");
|
||||
eprintln!("Image path: {:?}", path);
|
||||
// Get file metadata first
|
||||
let metadata = smol::fs::metadata(&path).await;
|
||||
eprintln!("Metadata result: {:?}", metadata.is_ok());
|
||||
|
||||
if let Ok(metadata) = metadata {
|
||||
let file_size = metadata.len();
|
||||
eprintln!("File size: {} bytes", file_size);
|
||||
|
||||
// Get model limits
|
||||
let (max_image_size, model_name) = thread
|
||||
.update_in(cx, |thread, _window, cx| {
|
||||
let model = thread.thread().read(cx).configured_model();
|
||||
let max_size = model
|
||||
.as_ref()
|
||||
.map(|m| m.model.max_image_size())
|
||||
.unwrap_or(10 * 1024 * 1024);
|
||||
let name = model.as_ref().map(|m| m.model.name().0.to_string());
|
||||
(max_size, name)
|
||||
})
|
||||
.ok()
|
||||
.unwrap_or((10 * 1024 * 1024, None));
|
||||
|
||||
eprintln!("Max image size: {}, Model: {:?}", max_image_size, model_name);
|
||||
eprintln!("File size: {:.2} MB, Limit: {:.2} MB",
|
||||
file_size as f64 / 1_048_576.0,
|
||||
max_image_size as f64 / 1_048_576.0);
|
||||
|
||||
if file_size > max_image_size {
|
||||
eprintln!("FILE SIZE EXCEEDS LIMIT!");
|
||||
let error_message = if let Some(model_name) = &model_name {
|
||||
if max_image_size == 0 {
|
||||
format!("{} does not support image attachments", model_name)
|
||||
} else {
|
||||
let size_mb = file_size as f64 / 1_048_576.0;
|
||||
let max_size_mb = max_image_size as f64 / 1_048_576.0;
|
||||
format!(
|
||||
"Image ({:.1} MB) exceeds {}'s {:.1} MB size limit",
|
||||
size_mb, model_name, max_size_mb
|
||||
)
|
||||
}
|
||||
} else {
|
||||
let size_mb = file_size as f64 / 1_048_576.0;
|
||||
format!("Image ({:.1} MB) exceeds the 10 MB size limit", size_mb)
|
||||
};
|
||||
|
||||
eprintln!("Showing error toast: {}", error_message);
|
||||
|
||||
cx.update(|_, cx| {
|
||||
eprintln!("Inside cx.update for toast");
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
eprintln!("Got workspace, showing toast!");
|
||||
let _ = workspace.update(cx, |workspace, cx| {
|
||||
use workspace::{Toast, notifications::NotificationId};
|
||||
|
||||
struct ImageRejectionToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<ImageRejectionToast>(),
|
||||
error_message,
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
eprintln!("Toast command issued!");
|
||||
} else {
|
||||
eprintln!("FAILED to upgrade workspace!");
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
} else {
|
||||
eprintln!("Image within size limits, loading file");
|
||||
// Load the image file
|
||||
match smol::fs::read(&path).await {
|
||||
Ok(data) => {
|
||||
eprintln!("Successfully read {} bytes", data.len());
|
||||
// Determine image format from extension
|
||||
let format = path.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.and_then(|ext| {
|
||||
match ext.to_lowercase().as_str() {
|
||||
"png" => Some(ImageFormat::Png),
|
||||
"jpg" | "jpeg" => Some(ImageFormat::Jpeg),
|
||||
"gif" => Some(ImageFormat::Gif),
|
||||
"webp" => Some(ImageFormat::Webp),
|
||||
"bmp" => Some(ImageFormat::Bmp),
|
||||
"svg" => Some(ImageFormat::Svg),
|
||||
"tiff" | "tif" => Some(ImageFormat::Tiff),
|
||||
_ => None
|
||||
}
|
||||
})
|
||||
.unwrap_or(ImageFormat::Png); // Default to PNG if unknown
|
||||
|
||||
// Create image from data
|
||||
let image = Image::from_bytes(format, data);
|
||||
let image_arc = Arc::new(image);
|
||||
|
||||
// Add to context store
|
||||
thread
|
||||
.update_in(cx, |thread, _window, cx| {
|
||||
thread.context_store().update(cx, |store, cx| {
|
||||
store.add_image_instance(image_arc, cx);
|
||||
});
|
||||
})
|
||||
.log_err();
|
||||
eprintln!("Image added to context store!");
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Failed to read image file: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eprintln!("Failed to get file metadata!");
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
eprintln!("Image task detached");
|
||||
} else {
|
||||
eprintln!("Not an image, using project path logic");
|
||||
// For non-image files, use the existing project path logic
|
||||
let project = this.project.clone();
|
||||
let context_store = thread.read(cx).context_store().clone();
|
||||
let path = path.to_path_buf();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
if let Some(task) = cx.update(|_, cx| {
|
||||
Workspace::project_path_for_path(project.clone(), &path, false, cx)
|
||||
}).ok() {
|
||||
if let Some((_, project_path)) = task.await.log_err() {
|
||||
context_store
|
||||
.update(cx, |store, cx| {
|
||||
store.add_file_from_path(project_path, false, cx).detach();
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
let tasks = paths
|
||||
.paths()
|
||||
.into_iter()
|
||||
.map(|path| {
|
||||
Workspace::project_path_for_path(this.project.clone(), &path, false, cx)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let mut paths = vec![];
|
||||
let mut added_worktrees = vec![];
|
||||
let opened_paths = futures::future::join_all(tasks).await;
|
||||
for entry in opened_paths {
|
||||
if let Some((worktree, project_path)) = entry.log_err() {
|
||||
added_worktrees.push(worktree);
|
||||
paths.push(project_path);
|
||||
}
|
||||
}
|
||||
ActiveView::TextThread { .. } => {
|
||||
eprintln!("In ActiveView::TextThread branch");
|
||||
// Keep existing behavior for text threads
|
||||
let tasks = paths
|
||||
.paths()
|
||||
.into_iter()
|
||||
.map(|path| {
|
||||
Workspace::project_path_for_path(this.project.clone(), &path, false, cx)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let mut paths = vec![];
|
||||
let mut added_worktrees = vec![];
|
||||
let opened_paths = futures::future::join_all(tasks).await;
|
||||
|
||||
for entry in opened_paths {
|
||||
if let Some((worktree, project_path)) = entry.log_err() {
|
||||
added_worktrees.push(worktree);
|
||||
paths.push(project_path);
|
||||
}
|
||||
}
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.handle_drop(paths, added_worktrees, window, cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
_ => {
|
||||
eprintln!("In unknown ActiveView branch");
|
||||
}
|
||||
}
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.handle_drop(paths, added_worktrees, window, cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -3153,47 +2940,20 @@ impl AgentPanel {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// This method is now only used for non-image files and text threads
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { thread, .. } => {
|
||||
let context_store = thread.read(cx).context_store().clone();
|
||||
|
||||
// All paths here should be non-image files
|
||||
context_store.update(cx, move |context_store, cx| {
|
||||
let mut tasks = Vec::new();
|
||||
for path in paths {
|
||||
tasks.push(context_store.add_file_from_path(path, false, cx));
|
||||
for project_path in &paths {
|
||||
tasks.push(context_store.add_file_from_path(
|
||||
project_path.clone(),
|
||||
false,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
|
||||
// Show error toasts for any file errors
|
||||
for result in results {
|
||||
if let Err(err) = result {
|
||||
cx.update(|cx| {
|
||||
if let Some(workspace) = cx
|
||||
.active_window()
|
||||
.and_then(|window| window.downcast::<Workspace>())
|
||||
{
|
||||
let _ = workspace.update(cx, |workspace, _, cx| {
|
||||
use workspace::{Toast, notifications::NotificationId};
|
||||
|
||||
struct FileLoadErrorToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(
|
||||
NotificationId::unique::<FileLoadErrorToast>(),
|
||||
err.to_string(),
|
||||
),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
cx.background_spawn(async move {
|
||||
futures::future::join_all(tasks).await;
|
||||
// Need to hold onto the worktrees until they have already been used when
|
||||
// opening the buffers.
|
||||
drop(added_worktrees);
|
||||
@@ -3216,15 +2976,6 @@ impl AgentPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
|
||||
let message = message.into();
|
||||
IconButton::new("copy", IconName::Copy)
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
|
||||
})
|
||||
.tooltip(Tooltip::text("Copy Error Message"))
|
||||
}
|
||||
|
||||
fn key_context(&self) -> KeyContext {
|
||||
let mut key_context = KeyContext::new_with_defaults();
|
||||
key_context.add("AgentPanel");
|
||||
@@ -3306,18 +3057,9 @@ impl Render for AgentPanel {
|
||||
thread.clone().into_any_element()
|
||||
})
|
||||
.children(self.render_tool_use_limit_reached(window, cx))
|
||||
.child(h_flex().child(message_editor.clone()))
|
||||
.when_some(thread.read(cx).last_error(), |this, last_error| {
|
||||
this.child(
|
||||
div()
|
||||
.absolute()
|
||||
.right_3()
|
||||
.bottom_12()
|
||||
.max_w_96()
|
||||
.py_2()
|
||||
.px_3()
|
||||
.elevation_2(cx)
|
||||
.occlude()
|
||||
.child(match last_error {
|
||||
ThreadError::PaymentRequired => {
|
||||
self.render_payment_required_error(thread, cx)
|
||||
@@ -3331,6 +3073,7 @@ impl Render for AgentPanel {
|
||||
.into_any(),
|
||||
)
|
||||
})
|
||||
.child(h_flex().child(message_editor.clone()))
|
||||
.child(self.render_drag_target(cx)),
|
||||
ActiveView::History => parent.child(self.history.clone()),
|
||||
ActiveView::TextThread {
|
||||
|
||||
@@ -9,7 +9,6 @@ mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_strip;
|
||||
mod debug;
|
||||
|
||||
mod inline_assistant;
|
||||
mod inline_prompt_editor;
|
||||
mod language_model_selector;
|
||||
@@ -93,6 +92,7 @@ actions!(
|
||||
|
||||
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = agent)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct NewThread {
|
||||
#[serde(default)]
|
||||
from_thread_id: Option<ThreadId>,
|
||||
@@ -100,6 +100,7 @@ pub struct NewThread {
|
||||
|
||||
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = agent)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct ManageProfiles {
|
||||
#[serde(default)]
|
||||
pub customize_tools: Option<AgentProfileId>,
|
||||
@@ -210,7 +211,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
||||
}
|
||||
}
|
||||
|
||||
let default = to_selected_model(&settings.default_model);
|
||||
let default = settings.default_model.as_ref().map(to_selected_model);
|
||||
let inline_assistant = settings
|
||||
.inline_assistant_model
|
||||
.as_ref()
|
||||
@@ -230,7 +231,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.select_default_model(Some(&default), cx);
|
||||
registry.select_default_model(default.as_ref(), cx);
|
||||
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
|
||||
registry.select_commit_message_model(commit_message.as_ref(), cx);
|
||||
registry.select_thread_summary_model(thread_summary.as_ref(), cx);
|
||||
|
||||
@@ -656,10 +656,6 @@ mod tests {
|
||||
false
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
0
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("{}/{}", self.provider_id.0, self.name.0)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,6 @@ use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, future};
|
||||
use gpui::AsyncApp;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
|
||||
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
@@ -63,7 +62,6 @@ use agent::{
|
||||
context_store::ContextStore,
|
||||
thread_store::{TextThreadStore, ThreadStore},
|
||||
};
|
||||
use workspace::{Toast, notifications::NotificationId};
|
||||
|
||||
#[derive(RegisterComponent)]
|
||||
pub struct MessageEditor {
|
||||
@@ -382,15 +380,11 @@ impl MessageEditor {
|
||||
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
|
||||
let context_task = self.reload_context(cx);
|
||||
let window_handle = window.window_handle();
|
||||
let workspace = self.workspace.clone();
|
||||
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
|
||||
let loaded_context = loaded_context.unwrap_or_default();
|
||||
|
||||
// Check for rejected images and show notifications
|
||||
Self::notify_rejected_images(&loaded_context, &model, &workspace, &cx);
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
@@ -418,80 +412,6 @@ impl MessageEditor {
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn notify_rejected_images(
|
||||
loaded_context: &agent::context::ContextLoadResult,
|
||||
model: &Arc<dyn language_model::LanguageModel>,
|
||||
workspace: &WeakEntity<Workspace>,
|
||||
cx: &AsyncApp,
|
||||
) {
|
||||
let rejected_images = loaded_context.loaded_context.check_image_size_limits(model);
|
||||
if rejected_images.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let workspace = workspace.clone();
|
||||
let model_name = rejected_images[0].model_name.clone();
|
||||
let max_size = model.max_image_size();
|
||||
let count = rejected_images.len();
|
||||
let rejected_images = rejected_images.clone();
|
||||
|
||||
cx.update(|cx| {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
let message = if max_size == 0 {
|
||||
Self::format_unsupported_images_message(&model_name, count)
|
||||
} else {
|
||||
Self::format_size_limit_message(
|
||||
&model_name,
|
||||
count,
|
||||
max_size,
|
||||
&rejected_images,
|
||||
)
|
||||
};
|
||||
|
||||
struct ImageRejectionToast;
|
||||
workspace.show_toast(
|
||||
Toast::new(NotificationId::unique::<ImageRejectionToast>(), message),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
fn format_unsupported_images_message(model_name: &str, count: usize) -> String {
|
||||
let plural = if count > 1 { "s" } else { "" };
|
||||
format!(
|
||||
"{} does not support image attachments. {} image{} will be excluded from your message.",
|
||||
model_name, count, plural
|
||||
)
|
||||
}
|
||||
|
||||
fn format_size_limit_message(
|
||||
model_name: &str,
|
||||
count: usize,
|
||||
max_size: u64,
|
||||
rejected_images: &[agent::context::RejectedImage],
|
||||
) -> String {
|
||||
let plural = if count > 1 { "s" } else { "" };
|
||||
let max_size_mb = max_size as f64 / 1_048_576.0;
|
||||
|
||||
// If only one image, show its specific size
|
||||
if count == 1 {
|
||||
let image_size_mb = rejected_images[0].size as f64 / 1_048_576.0;
|
||||
format!(
|
||||
"Image ({:.1} MB) exceeds {}'s {:.1} MB size limit and will be excluded.",
|
||||
image_size_mb, model_name, max_size_mb
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{} image{} exceeded {}'s {:.1} MB size limit and will be excluded.",
|
||||
count, plural, model_name, max_size_mb
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_editing(cx);
|
||||
@@ -1330,9 +1250,7 @@ impl MessageEditor {
|
||||
self.thread
|
||||
.read(cx)
|
||||
.configured_model()
|
||||
.map_or(false, |model| {
|
||||
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
|
||||
}
|
||||
|
||||
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
|
||||
|
||||
@@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use http_client::http::{self, HeaderMap, HeaderValue};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
@@ -356,7 +356,7 @@ pub async fn complete(
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(AnthropicError::HttpSend)?;
|
||||
let status = response.status();
|
||||
let status_code = response.status();
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
@@ -364,12 +364,12 @@ pub async fn complete(
|
||||
.await
|
||||
.map_err(AnthropicError::ReadResponse)?;
|
||||
|
||||
if status.is_success() {
|
||||
if status_code.is_success() {
|
||||
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
|
||||
} else {
|
||||
Err(AnthropicError::HttpResponseError {
|
||||
status: status.as_u16(),
|
||||
body,
|
||||
status_code,
|
||||
message: body,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -444,11 +444,7 @@ impl RateLimitInfo {
|
||||
}
|
||||
|
||||
Self {
|
||||
retry_after: headers
|
||||
.get("retry-after")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.map(Duration::from_secs),
|
||||
retry_after: parse_retry_after(headers),
|
||||
requests: RateLimit::from_headers("requests", headers).ok(),
|
||||
tokens: RateLimit::from_headers("tokens", headers).ok(),
|
||||
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
|
||||
@@ -457,6 +453,17 @@ impl RateLimitInfo {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
|
||||
/// seconds). Note that other services might specify an HTTP date or some other format for this
|
||||
/// header. Returns `None` if the header is not present or cannot be parsed.
|
||||
pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
|
||||
headers
|
||||
.get("retry-after")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.map(Duration::from_secs)
|
||||
}
|
||||
|
||||
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
|
||||
Ok(headers
|
||||
.get(key)
|
||||
@@ -520,6 +527,10 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
})
|
||||
.boxed();
|
||||
Ok((stream, Some(rate_limits)))
|
||||
} else if response.status().as_u16() == 529 {
|
||||
Err(AnthropicError::ServerOverloaded {
|
||||
retry_after: rate_limits.retry_after,
|
||||
})
|
||||
} else if let Some(retry_after) = rate_limits.retry_after {
|
||||
Err(AnthropicError::RateLimit { retry_after })
|
||||
} else {
|
||||
@@ -532,10 +543,9 @@ pub async fn stream_completion_with_rate_limit_info(
|
||||
|
||||
match serde_json::from_str::<Event>(&body) {
|
||||
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
|
||||
Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
|
||||
Err(_) => Err(AnthropicError::HttpResponseError {
|
||||
status: response.status().as_u16(),
|
||||
body: body,
|
||||
Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
|
||||
status_code: response.status(),
|
||||
message: body,
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -801,16 +811,19 @@ pub enum AnthropicError {
|
||||
ReadResponse(io::Error),
|
||||
|
||||
/// HTTP error response from the API
|
||||
HttpResponseError { status: u16, body: String },
|
||||
HttpResponseError {
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Rate limit exceeded
|
||||
RateLimit { retry_after: Duration },
|
||||
|
||||
/// Server overloaded
|
||||
ServerOverloaded { retry_after: Option<Duration> },
|
||||
|
||||
/// API returned an error response
|
||||
ApiError(ApiError),
|
||||
|
||||
/// Unexpected response format
|
||||
UnexpectedResponseFormat(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Error)]
|
||||
|
||||
@@ -2140,7 +2140,8 @@ impl AssistantContext {
|
||||
);
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(_) |
|
||||
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
|
||||
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ pub use find_path_tool::FindPathToolInput;
|
||||
pub use grep_tool::{GrepTool, GrepToolInput};
|
||||
pub use open_tool::OpenTool;
|
||||
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
|
||||
pub use schema::root_schema_for;
|
||||
pub use terminal_tool::TerminalTool;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
|
||||
@@ -29,6 +29,7 @@ use std::{
|
||||
path::Path,
|
||||
str::FromStr,
|
||||
sync::mpsc,
|
||||
time::Duration,
|
||||
};
|
||||
use util::path;
|
||||
|
||||
@@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
|
||||
match request().await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
|
||||
Ok(err) => match err {
|
||||
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
|
||||
Ok(err) => match &err {
|
||||
LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
|
||||
| LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
|
||||
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
|
||||
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
|
||||
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
|
||||
eprintln!(
|
||||
"Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
|
||||
"Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
|
||||
);
|
||||
Timer::after(retry_after + jitter).await;
|
||||
continue;
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -314,7 +314,7 @@ impl Tool for GrepTool {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assistant_tool::Tool;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project, WorktreeSettings};
|
||||
|
||||
@@ -226,7 +226,7 @@ impl Tool for ListDirectoryTool {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assistant_tool::Tool;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use indoc::indoc;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project, WorktreeSettings};
|
||||
|
||||
@@ -289,7 +289,7 @@ impl Tool for ReadFileTool {
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal};
|
||||
use gpui::{TestAppContext, UpdateGlobal};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use project::{FakeFs, Project, WorktreeSettings};
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use schemars::{
|
||||
JsonSchema,
|
||||
schema::{RootSchema, Schema, SchemaObject},
|
||||
JsonSchema, Schema,
|
||||
generate::SchemaSettings,
|
||||
transform::{Transform, transform_subschemas},
|
||||
};
|
||||
|
||||
pub fn json_schema_for<T: JsonSchema>(
|
||||
@@ -13,7 +14,7 @@ pub fn json_schema_for<T: JsonSchema>(
|
||||
}
|
||||
|
||||
fn schema_to_json(
|
||||
schema: &RootSchema,
|
||||
schema: &Schema,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut value = serde_json::to_value(schema)?;
|
||||
@@ -21,58 +22,42 @@ fn schema_to_json(
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
|
||||
pub fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
let mut generator = match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => schemars::SchemaGenerator::default(),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
|
||||
schemars::r#gen::SchemaSettings::default()
|
||||
.with(|settings| {
|
||||
settings.meta_schema = None;
|
||||
settings.inline_subschemas = true;
|
||||
settings
|
||||
.visitors
|
||||
.push(Box::new(TransformToJsonSchemaSubsetVisitor));
|
||||
})
|
||||
.into_generator()
|
||||
}
|
||||
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
|
||||
// TODO: Gemini docs mention using a subset of OpenAPI 3, so this may benefit from using
|
||||
// `SchemaSettings::openapi3()`.
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::draft07()
|
||||
.with(|settings| {
|
||||
settings.meta_schema = None;
|
||||
settings.inline_subschemas = true;
|
||||
})
|
||||
.with_transform(ToJsonSchemaSubsetTransform)
|
||||
.into_generator(),
|
||||
};
|
||||
generator.root_schema_for::<T>()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TransformToJsonSchemaSubsetVisitor;
|
||||
struct ToJsonSchemaSubsetTransform;
|
||||
|
||||
impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
|
||||
fn visit_root_schema(&mut self, root: &mut RootSchema) {
|
||||
schemars::visit::visit_root_schema(self, root)
|
||||
}
|
||||
|
||||
fn visit_schema(&mut self, schema: &mut Schema) {
|
||||
schemars::visit::visit_schema(self, schema)
|
||||
}
|
||||
|
||||
fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
|
||||
impl Transform for ToJsonSchemaSubsetTransform {
|
||||
fn transform(&mut self, schema: &mut Schema) {
|
||||
// Ensure that the type field is not an array, this happens when we use
|
||||
// Option<T>, the type will be [T, "null"].
|
||||
if let Some(instance_type) = schema.instance_type.take() {
|
||||
schema.instance_type = match instance_type {
|
||||
schemars::schema::SingleOrVec::Single(t) => {
|
||||
Some(schemars::schema::SingleOrVec::Single(t))
|
||||
if let Some(type_field) = schema.get_mut("type") {
|
||||
if let Some(types) = type_field.as_array() {
|
||||
if let Some(first_type) = types.first() {
|
||||
*type_field = first_type.clone();
|
||||
}
|
||||
schemars::schema::SingleOrVec::Vec(items) => items
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(schemars::schema::SingleOrVec::from),
|
||||
};
|
||||
}
|
||||
|
||||
// One of is not supported, use anyOf instead.
|
||||
if let Some(subschema) = schema.subschemas.as_mut() {
|
||||
if let Some(one_of) = subschema.one_of.take() {
|
||||
subschema.any_of = Some(one_of);
|
||||
}
|
||||
}
|
||||
|
||||
schemars::visit::visit_schema_object(self, schema)
|
||||
// oneOf is not supported, use anyOf instead
|
||||
if let Some(one_of) = schema.remove("oneOf") {
|
||||
schema.insert("anyOf".to_string(), one_of);
|
||||
}
|
||||
|
||||
transform_subschemas(self, schema);
|
||||
}
|
||||
}
|
||||
|
||||
47
crates/assistant_tools/src/templates/edit_agent.hbs
Normal file
47
crates/assistant_tools/src/templates/edit_agent.hbs
Normal file
@@ -0,0 +1,47 @@
|
||||
You are an expert text editor. Taking the following file as an input:
|
||||
|
||||
```{{path}}
|
||||
{{file_content}}
|
||||
```
|
||||
|
||||
Produce a series of edits following the given user instructions:
|
||||
|
||||
<user_instructions>
|
||||
{{instructions}}
|
||||
</user_instructions>
|
||||
|
||||
Your response must be a series of edits in the following format:
|
||||
|
||||
<edits>
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 2 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 2 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 3 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 3 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
Rules for editing:
|
||||
|
||||
- `old_text` represents full lines (including indentation) in the input file that will be replaced with `new_text`
|
||||
- It is crucial that `old_text` is unique and unambiguous.
|
||||
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
|
||||
- If you want to replace all occurrences, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
|
||||
- Don't explain why you made a change, just report the edits.
|
||||
- Make sure you follow the instructions carefully and thoroughly, avoid doing *less* or *more* than instructed.
|
||||
|
||||
<edits>
|
||||
@@ -25,5 +25,4 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
mod models;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::{Context as _, Error, Result, anyhow};
|
||||
use anyhow::{Context, Error, Result, anyhow};
|
||||
use aws_sdk_bedrockruntime as bedrock;
|
||||
pub use aws_sdk_bedrockruntime as bedrock_client;
|
||||
pub use aws_sdk_bedrockruntime::types::{
|
||||
@@ -24,9 +21,10 @@ pub use bedrock::types::{
|
||||
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
|
||||
};
|
||||
use futures::stream::{self, BoxStream, Stream};
|
||||
use futures::stream::{self, BoxStream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Number, Value};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
pub use crate::models::*;
|
||||
@@ -34,70 +32,59 @@ pub use crate::models::*;
|
||||
pub async fn stream_completion(
|
||||
client: bedrock::Client,
|
||||
request: Request,
|
||||
handle: tokio::runtime::Handle,
|
||||
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
|
||||
handle
|
||||
.spawn(async move {
|
||||
let mut response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into());
|
||||
let mut response = bedrock::Client::converse_stream(&client)
|
||||
.model_id(request.model.clone())
|
||||
.set_messages(request.messages.into());
|
||||
|
||||
if let Some(Thinking::Enabled {
|
||||
budget_tokens: Some(budget_tokens),
|
||||
}) = request.thinking
|
||||
{
|
||||
response =
|
||||
response.additional_model_request_fields(Document::Object(HashMap::from([(
|
||||
"thinking".to_string(),
|
||||
Document::from(HashMap::from([
|
||||
("type".to_string(), Document::String("enabled".to_string())),
|
||||
(
|
||||
"budget_tokens".to_string(),
|
||||
Document::Number(AwsNumber::PosInt(budget_tokens)),
|
||||
),
|
||||
])),
|
||||
)])));
|
||||
}
|
||||
if let Some(Thinking::Enabled {
|
||||
budget_tokens: Some(budget_tokens),
|
||||
}) = request.thinking
|
||||
{
|
||||
let thinking_config = HashMap::from([
|
||||
("type".to_string(), Document::String("enabled".to_string())),
|
||||
(
|
||||
"budget_tokens".to_string(),
|
||||
Document::Number(AwsNumber::PosInt(budget_tokens)),
|
||||
),
|
||||
]);
|
||||
response = response.additional_model_request_fields(Document::Object(HashMap::from([(
|
||||
"thinking".to_string(),
|
||||
Document::from(thinking_config),
|
||||
)])));
|
||||
}
|
||||
|
||||
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
|
||||
response = response.set_tool_config(request.tools);
|
||||
}
|
||||
if request
|
||||
.tools
|
||||
.as_ref()
|
||||
.map_or(false, |t| !t.tools.is_empty())
|
||||
{
|
||||
response = response.set_tool_config(request.tools);
|
||||
}
|
||||
|
||||
let response = response.send().await;
|
||||
let output = response
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send API request to Bedrock");
|
||||
|
||||
match response {
|
||||
Ok(output) => {
|
||||
let stream: Pin<
|
||||
Box<
|
||||
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
|
||||
+ Send,
|
||||
>,
|
||||
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
|
||||
match stream.recv().await {
|
||||
Ok(Some(output)) => Some(({ Ok(output) }, stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => {
|
||||
Some((
|
||||
// TODO: Figure out how we can capture Throttling Exceptions
|
||||
Err(BedrockError::ClientError(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
))),
|
||||
stream,
|
||||
))
|
||||
}
|
||||
}
|
||||
}));
|
||||
Ok(stream)
|
||||
}
|
||||
Err(err) => Err(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
let stream = Box::pin(stream::unfold(
|
||||
output?.stream,
|
||||
move |mut stream| async move {
|
||||
match stream.recv().await {
|
||||
Ok(Some(output)) => Some((Ok(output), stream)),
|
||||
Ok(None) => None,
|
||||
Err(err) => Some((
|
||||
Err(BedrockError::ClientError(anyhow!(
|
||||
"{:?}",
|
||||
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
|
||||
))),
|
||||
stream,
|
||||
)),
|
||||
}
|
||||
})
|
||||
.await
|
||||
.context("spawning a task")?
|
||||
},
|
||||
));
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
pub fn aws_document_to_value(document: &Document) -> Value {
|
||||
|
||||
@@ -12,7 +12,6 @@ pub struct CallSettings {
|
||||
|
||||
/// Configuration of voice calls in Zed.
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct CallSettingsContent {
|
||||
/// Whether the microphone should be muted when joining a channel or a call.
|
||||
///
|
||||
|
||||
@@ -22,9 +22,7 @@ use gpui::{
|
||||
use language::{
|
||||
Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig,
|
||||
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
|
||||
language_settings::{
|
||||
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
|
||||
},
|
||||
language_settings::{AllLanguageSettings, Formatter, PrettierSettings, SelectedFormatter},
|
||||
tree_sitter_rust, tree_sitter_typescript,
|
||||
};
|
||||
use lsp::{LanguageServerId, OneOf};
|
||||
@@ -4591,15 +4589,13 @@ async fn test_formatting_buffer(
|
||||
cx_a.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
|
||||
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
|
||||
vec![Formatter::External {
|
||||
file.defaults.formatter =
|
||||
Some(SelectedFormatter::List(vec![Formatter::External {
|
||||
command: "awk".into(),
|
||||
arguments: Some(
|
||||
vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
|
||||
),
|
||||
}]
|
||||
.into(),
|
||||
)));
|
||||
}]));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4699,9 +4695,10 @@ async fn test_prettier_formatting_buffer(
|
||||
cx_b.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
|
||||
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
|
||||
vec![Formatter::LanguageServer { name: None }].into(),
|
||||
)));
|
||||
file.defaults.formatter =
|
||||
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
|
||||
name: None,
|
||||
}]));
|
||||
file.defaults.prettier = Some(PrettierSettings {
|
||||
allowed: true,
|
||||
..PrettierSettings::default()
|
||||
|
||||
@@ -6,16 +6,12 @@ use debugger_ui::debugger_panel::DebugPanel;
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs as _, RemoveOptions};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{
|
||||
AppContext as _, BackgroundExecutor, SemanticVersion, TestAppContext, UpdateGlobal as _,
|
||||
VisualContext,
|
||||
};
|
||||
use gpui::{BackgroundExecutor, SemanticVersion, TestAppContext, UpdateGlobal as _, VisualContext};
|
||||
use http_client::BlockedHttpClient;
|
||||
use language::{
|
||||
FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LanguageRegistry,
|
||||
language_settings::{
|
||||
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
|
||||
language_settings,
|
||||
AllLanguageSettings, Formatter, PrettierSettings, SelectedFormatter, language_settings,
|
||||
},
|
||||
tree_sitter_typescript,
|
||||
};
|
||||
@@ -505,9 +501,10 @@ async fn test_ssh_collaboration_formatting_with_prettier(
|
||||
cx_b.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
|
||||
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
|
||||
vec![Formatter::LanguageServer { name: None }].into(),
|
||||
)));
|
||||
file.defaults.formatter =
|
||||
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
|
||||
name: None,
|
||||
}]));
|
||||
file.defaults.prettier = Some(PrettierSettings {
|
||||
allowed: true,
|
||||
..PrettierSettings::default()
|
||||
|
||||
@@ -28,7 +28,6 @@ pub struct ChatPanelSettings {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct ChatPanelSettingsContent {
|
||||
/// When to show the panel button in the status bar.
|
||||
///
|
||||
@@ -52,7 +51,6 @@ pub struct NotificationPanelSettings {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct PanelSettingsContent {
|
||||
/// Whether to show the panel button in the status bar.
|
||||
///
|
||||
@@ -69,7 +67,6 @@ pub struct PanelSettingsContent {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct MessageEditorSettings {
|
||||
/// Whether to automatically replace emoji shortcodes with emoji characters.
|
||||
/// For example: typing `:wave:` gets replaced with `👋`.
|
||||
|
||||
@@ -41,7 +41,7 @@ pub struct CommandPalette {
|
||||
/// Removes subsequent whitespace characters and double colons from the query.
|
||||
///
|
||||
/// This improves the likelihood of a match by either humanized name or keymap-style name.
|
||||
fn normalize_query(input: &str) -> String {
|
||||
pub fn normalize_action_query(input: &str) -> String {
|
||||
let mut result = String::with_capacity(input.len());
|
||||
let mut last_char = None;
|
||||
|
||||
@@ -297,7 +297,7 @@ impl PickerDelegate for CommandPaletteDelegate {
|
||||
let mut commands = self.all_commands.clone();
|
||||
let hit_counts = self.hit_counts();
|
||||
let executor = cx.background_executor().clone();
|
||||
let query = normalize_query(query.as_str());
|
||||
let query = normalize_action_query(query.as_str());
|
||||
async move {
|
||||
commands.sort_by_key(|action| {
|
||||
(
|
||||
@@ -311,29 +311,17 @@ impl PickerDelegate for CommandPaletteDelegate {
|
||||
.enumerate()
|
||||
.map(|(ix, command)| StringMatchCandidate::new(ix, &command.name))
|
||||
.collect::<Vec<_>>();
|
||||
let matches = if query.is_empty() {
|
||||
candidates
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, candidate)| StringMatch {
|
||||
candidate_id: index,
|
||||
string: candidate.string,
|
||||
positions: Vec::new(),
|
||||
score: 0.0,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
fuzzy::match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
true,
|
||||
true,
|
||||
10000,
|
||||
&Default::default(),
|
||||
executor,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
let matches = fuzzy::match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
true,
|
||||
true,
|
||||
10000,
|
||||
&Default::default(),
|
||||
executor,
|
||||
)
|
||||
.await;
|
||||
|
||||
tx.send((commands, matches)).await.log_err();
|
||||
}
|
||||
@@ -422,8 +410,8 @@ impl PickerDelegate for CommandPaletteDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let r#match = self.matches.get(ix)?;
|
||||
let command = self.commands.get(r#match.candidate_id)?;
|
||||
let matching_command = self.matches.get(ix)?;
|
||||
let command = self.commands.get(matching_command.candidate_id)?;
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
@@ -436,7 +424,7 @@ impl PickerDelegate for CommandPaletteDelegate {
|
||||
.justify_between()
|
||||
.child(HighlightedLabel::new(
|
||||
command.name.clone(),
|
||||
r#match.positions.clone(),
|
||||
matching_command.positions.clone(),
|
||||
))
|
||||
.children(KeyBinding::for_action_in(
|
||||
&*command.action,
|
||||
@@ -512,19 +500,28 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_normalize_query() {
|
||||
assert_eq!(normalize_query("editor: backspace"), "editor: backspace");
|
||||
assert_eq!(normalize_query("editor: backspace"), "editor: backspace");
|
||||
assert_eq!(normalize_query("editor: backspace"), "editor: backspace");
|
||||
assert_eq!(
|
||||
normalize_query("editor::GoToDefinition"),
|
||||
normalize_action_query("editor: backspace"),
|
||||
"editor: backspace"
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_action_query("editor: backspace"),
|
||||
"editor: backspace"
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_action_query("editor: backspace"),
|
||||
"editor: backspace"
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_action_query("editor::GoToDefinition"),
|
||||
"editor:GoToDefinition"
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_query("editor::::GoToDefinition"),
|
||||
normalize_action_query("editor::::GoToDefinition"),
|
||||
"editor:GoToDefinition"
|
||||
);
|
||||
assert_eq!(
|
||||
normalize_query("editor: :GoToDefinition"),
|
||||
normalize_action_query("editor: :GoToDefinition"),
|
||||
"editor: :GoToDefinition"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use gpui::{AsyncApp, SharedString};
|
||||
pub use http_client::{HttpClient, github::latest_github_release};
|
||||
use language::{LanguageName, LanguageToolchainStore};
|
||||
use node_runtime::NodeRuntime;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::WorktreeId;
|
||||
use smol::fs::File;
|
||||
@@ -47,7 +48,10 @@ pub trait DapDelegate: Send + Sync + 'static {
|
||||
async fn shell_env(&self) -> collections::HashMap<String, String>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
|
||||
#[derive(
|
||||
Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize, JsonSchema,
|
||||
)]
|
||||
#[serde(transparent)]
|
||||
pub struct DebugAdapterName(pub SharedString);
|
||||
|
||||
impl Deref for DebugAdapterName {
|
||||
|
||||
@@ -22,17 +22,16 @@ impl CodeLldbDebugAdapter {
|
||||
async fn request_args(
|
||||
&self,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
task_definition: &DebugTaskDefinition,
|
||||
mut configuration: Value,
|
||||
label: &str,
|
||||
) -> Result<dap::StartDebuggingRequestArguments> {
|
||||
// CodeLLDB uses `name` for a terminal label.
|
||||
let mut configuration = task_definition.config.clone();
|
||||
|
||||
let obj = configuration
|
||||
.as_object_mut()
|
||||
.context("CodeLLDB is not a valid json object")?;
|
||||
|
||||
// CodeLLDB uses `name` for a terminal label.
|
||||
obj.entry("name")
|
||||
.or_insert(Value::String(String::from(task_definition.label.as_ref())));
|
||||
.or_insert(Value::String(String::from(label)));
|
||||
|
||||
obj.entry("cwd")
|
||||
.or_insert(delegate.worktree_root_path().to_string_lossy().into());
|
||||
@@ -361,17 +360,31 @@ impl DebugAdapter for CodeLldbDebugAdapter {
|
||||
self.path_to_codelldb.set(path.clone()).ok();
|
||||
command = Some(path);
|
||||
};
|
||||
|
||||
let mut json_config = config.config.clone();
|
||||
Ok(DebugAdapterBinary {
|
||||
command: Some(command.unwrap()),
|
||||
cwd: Some(delegate.worktree_root_path().to_path_buf()),
|
||||
arguments: user_args.unwrap_or_else(|| {
|
||||
vec![
|
||||
"--settings".into(),
|
||||
json!({"sourceLanguages": ["cpp", "rust"]}).to_string(),
|
||||
]
|
||||
if let Some(config) = json_config.as_object_mut()
|
||||
&& let Some(source_languages) = config.get("sourceLanguages").filter(|value| {
|
||||
value
|
||||
.as_array()
|
||||
.map_or(false, |array| array.iter().all(Value::is_string))
|
||||
})
|
||||
{
|
||||
let ret = vec![
|
||||
"--settings".into(),
|
||||
json!({"sourceLanguages": source_languages}).to_string(),
|
||||
];
|
||||
config.remove("sourceLanguages");
|
||||
ret
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}),
|
||||
request_args: self.request_args(delegate, &config).await?,
|
||||
request_args: self
|
||||
.request_args(delegate, json_config, &config.label)
|
||||
.await?,
|
||||
envs: HashMap::default(),
|
||||
connection: None,
|
||||
})
|
||||
|
||||
@@ -282,6 +282,10 @@ impl DebugAdapter for JsDebugAdapter {
|
||||
"description": "Automatically stop program after launch",
|
||||
"default": false
|
||||
},
|
||||
"attachSimplePort": {
|
||||
"type": "number",
|
||||
"description": "If set, attaches to the process via the given port. This is generally no longer necessary for Node.js programs and loses the ability to debug child processes, but can be useful in more esoteric scenarios such as with Deno and Docker launches. If set to 0, a random port will be chosen and --inspect-brk added to the launch arguments automatically."
|
||||
},
|
||||
"runtimeExecutable": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Runtime to use, an absolute path or the name of a runtime available on PATH",
|
||||
|
||||
@@ -434,9 +434,14 @@ impl LogStore {
|
||||
|
||||
fn clean_sessions(&mut self, cx: &mut Context<Self>) {
|
||||
self.projects.values_mut().for_each(|project| {
|
||||
project
|
||||
.debug_sessions
|
||||
.retain(|_, session| !session.is_terminated);
|
||||
let mut allowed_terminated_sessions = 10u32;
|
||||
project.debug_sessions.retain(|_, session| {
|
||||
if !session.is_terminated {
|
||||
return true;
|
||||
}
|
||||
allowed_terminated_sessions = allowed_terminated_sessions.saturating_sub(1);
|
||||
allowed_terminated_sessions > 0
|
||||
});
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
|
||||
@@ -900,7 +900,7 @@ impl RunningState {
|
||||
|
||||
|
||||
let config_is_valid = request_type.is_ok();
|
||||
|
||||
let mut extra_config = Value::Null;
|
||||
let build_output = if let Some(build) = build {
|
||||
let (task_template, locator_name) = match build {
|
||||
BuildTaskDefinition::Template {
|
||||
@@ -930,6 +930,7 @@ impl RunningState {
|
||||
};
|
||||
|
||||
let locator_name = if let Some(locator_name) = locator_name {
|
||||
extra_config = config.clone();
|
||||
debug_assert!(!config_is_valid);
|
||||
Some(locator_name)
|
||||
} else if !config_is_valid {
|
||||
@@ -945,6 +946,7 @@ impl RunningState {
|
||||
});
|
||||
if let Ok(t) = task {
|
||||
t.await.and_then(|scenario| {
|
||||
extra_config = scenario.config;
|
||||
match scenario.build {
|
||||
Some(BuildTaskDefinition::Template {
|
||||
locator_name, ..
|
||||
@@ -1008,13 +1010,13 @@ impl RunningState {
|
||||
if !exit_status.success() {
|
||||
anyhow::bail!("Build failed");
|
||||
}
|
||||
Some((task.resolved.clone(), locator_name))
|
||||
Some((task.resolved.clone(), locator_name, extra_config))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if config_is_valid {
|
||||
} else if let Some((task, locator_name)) = build_output {
|
||||
} else if let Some((task, locator_name, extra_config)) = build_output {
|
||||
let locator_name =
|
||||
locator_name.with_context(|| {
|
||||
format!("Could not find a valid locator for a build task and configure is invalid with error: {}", request_type.err()
|
||||
@@ -1039,6 +1041,8 @@ impl RunningState {
|
||||
.with_context(|| anyhow!("{}: is not a valid adapter name", &adapter))?.config_from_zed_format(zed_config)
|
||||
.await?;
|
||||
config = scenario.config;
|
||||
util::merge_non_null_json_value_into(extra_config, &mut config);
|
||||
|
||||
Self::substitute_variables_in_config(&mut config, &task_context);
|
||||
} else {
|
||||
let Err(e) = request_type else {
|
||||
|
||||
@@ -61,6 +61,7 @@ parking_lot.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
rpc.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -11541,66 +11541,90 @@ impl Editor {
|
||||
let language_settings = buffer.language_settings_at(selection.head(), cx);
|
||||
let language_scope = buffer.language_scope_at(selection.head());
|
||||
|
||||
let indent_and_prefix_for_row =
|
||||
|row: u32| -> (IndentSize, Option<String>, Option<String>) {
|
||||
let indent = buffer.indent_size_for_line(MultiBufferRow(row));
|
||||
let (comment_prefix, rewrap_prefix) =
|
||||
if let Some(language_scope) = &language_scope {
|
||||
let indent_end = Point::new(row, indent.len);
|
||||
let comment_prefix = language_scope
|
||||
.line_comment_prefixes()
|
||||
.iter()
|
||||
.find(|prefix| buffer.contains_str_at(indent_end, prefix))
|
||||
.map(|prefix| prefix.to_string());
|
||||
let line_end = Point::new(row, buffer.line_len(MultiBufferRow(row)));
|
||||
let line_text_after_indent = buffer
|
||||
.text_for_range(indent_end..line_end)
|
||||
.collect::<String>();
|
||||
let rewrap_prefix = language_scope
|
||||
.rewrap_prefixes()
|
||||
.iter()
|
||||
.find_map(|prefix_regex| {
|
||||
prefix_regex.find(&line_text_after_indent).map(|mat| {
|
||||
if mat.start() == 0 {
|
||||
Some(mat.as_str().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.flatten();
|
||||
(comment_prefix, rewrap_prefix)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
(indent, comment_prefix, rewrap_prefix)
|
||||
};
|
||||
|
||||
let mut ranges = Vec::new();
|
||||
let mut current_range_start = first_row;
|
||||
let from_empty_selection = selection.is_empty();
|
||||
|
||||
let mut current_range_start = first_row;
|
||||
let mut prev_row = first_row;
|
||||
let mut prev_indent = buffer.indent_size_for_line(MultiBufferRow(first_row));
|
||||
let mut prev_comment_prefix = if let Some(language_scope) = &language_scope {
|
||||
let indent = buffer.indent_size_for_line(MultiBufferRow(first_row));
|
||||
let indent_end = Point::new(first_row, indent.len);
|
||||
language_scope
|
||||
.line_comment_prefixes()
|
||||
.iter()
|
||||
.find(|prefix| buffer.contains_str_at(indent_end, prefix))
|
||||
.cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (
|
||||
mut current_range_indent,
|
||||
mut current_range_comment_prefix,
|
||||
mut current_range_rewrap_prefix,
|
||||
) = indent_and_prefix_for_row(first_row);
|
||||
|
||||
for row in non_blank_rows_iter.skip(1) {
|
||||
let has_paragraph_break = row > prev_row + 1;
|
||||
|
||||
let row_indent = buffer.indent_size_for_line(MultiBufferRow(row));
|
||||
let row_comment_prefix = if let Some(language_scope) = &language_scope {
|
||||
let indent = buffer.indent_size_for_line(MultiBufferRow(row));
|
||||
let indent_end = Point::new(row, indent.len);
|
||||
language_scope
|
||||
.line_comment_prefixes()
|
||||
.iter()
|
||||
.find(|prefix| buffer.contains_str_at(indent_end, prefix))
|
||||
.cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (row_indent, row_comment_prefix, row_rewrap_prefix) =
|
||||
indent_and_prefix_for_row(row);
|
||||
|
||||
let has_boundary_change =
|
||||
row_indent != prev_indent || row_comment_prefix != prev_comment_prefix;
|
||||
let has_indent_change = row_indent != current_range_indent;
|
||||
let has_comment_change = row_comment_prefix != current_range_comment_prefix;
|
||||
|
||||
let has_boundary_change = has_comment_change
|
||||
|| row_rewrap_prefix.is_some()
|
||||
|| (has_indent_change && current_range_comment_prefix.is_some());
|
||||
|
||||
if has_paragraph_break || has_boundary_change {
|
||||
ranges.push((
|
||||
language_settings.clone(),
|
||||
Point::new(current_range_start, 0)
|
||||
..Point::new(prev_row, buffer.line_len(MultiBufferRow(prev_row))),
|
||||
prev_indent,
|
||||
prev_comment_prefix.clone(),
|
||||
current_range_indent,
|
||||
current_range_comment_prefix.clone(),
|
||||
current_range_rewrap_prefix.clone(),
|
||||
from_empty_selection,
|
||||
));
|
||||
current_range_start = row;
|
||||
current_range_indent = row_indent;
|
||||
current_range_comment_prefix = row_comment_prefix;
|
||||
current_range_rewrap_prefix = row_rewrap_prefix;
|
||||
}
|
||||
|
||||
prev_row = row;
|
||||
prev_indent = row_indent;
|
||||
prev_comment_prefix = row_comment_prefix;
|
||||
}
|
||||
|
||||
ranges.push((
|
||||
language_settings.clone(),
|
||||
Point::new(current_range_start, 0)
|
||||
..Point::new(prev_row, buffer.line_len(MultiBufferRow(prev_row))),
|
||||
prev_indent,
|
||||
prev_comment_prefix,
|
||||
current_range_indent,
|
||||
current_range_comment_prefix,
|
||||
current_range_rewrap_prefix,
|
||||
from_empty_selection,
|
||||
));
|
||||
|
||||
@@ -11610,8 +11634,14 @@ impl Editor {
|
||||
let mut edits = Vec::new();
|
||||
let mut rewrapped_row_ranges = Vec::<RangeInclusive<u32>>::new();
|
||||
|
||||
for (language_settings, wrap_range, indent_size, comment_prefix, from_empty_selection) in
|
||||
wrap_ranges
|
||||
for (
|
||||
language_settings,
|
||||
wrap_range,
|
||||
indent_size,
|
||||
comment_prefix,
|
||||
rewrap_prefix,
|
||||
from_empty_selection,
|
||||
) in wrap_ranges
|
||||
{
|
||||
let mut start_row = wrap_range.start.row;
|
||||
let mut end_row = wrap_range.end.row;
|
||||
@@ -11627,12 +11657,16 @@ impl Editor {
|
||||
|
||||
let tab_size = language_settings.tab_size;
|
||||
|
||||
let mut line_prefix = indent_size.chars().collect::<String>();
|
||||
let indent_prefix = indent_size.chars().collect::<String>();
|
||||
let mut line_prefix = indent_prefix.clone();
|
||||
let mut inside_comment = false;
|
||||
if let Some(prefix) = &comment_prefix {
|
||||
line_prefix.push_str(prefix);
|
||||
inside_comment = true;
|
||||
}
|
||||
if let Some(prefix) = &rewrap_prefix {
|
||||
line_prefix.push_str(prefix);
|
||||
}
|
||||
|
||||
let allow_rewrap_based_on_language = match language_settings.allow_rewrap {
|
||||
RewrapBehavior::InComments => inside_comment,
|
||||
@@ -11679,12 +11713,18 @@ impl Editor {
|
||||
let selection_text = buffer.text_for_range(start..end).collect::<String>();
|
||||
let Some(lines_without_prefixes) = selection_text
|
||||
.lines()
|
||||
.map(|line| {
|
||||
line.strip_prefix(&line_prefix)
|
||||
.or_else(|| line.trim_start().strip_prefix(&line_prefix.trim_start()))
|
||||
.with_context(|| {
|
||||
format!("line did not start with prefix {line_prefix:?}: {line:?}")
|
||||
})
|
||||
.enumerate()
|
||||
.map(|(ix, line)| {
|
||||
let line_trimmed = line.trim_start();
|
||||
if rewrap_prefix.is_some() && ix > 0 {
|
||||
Ok(line_trimmed)
|
||||
} else {
|
||||
line_trimmed
|
||||
.strip_prefix(&line_prefix.trim_start())
|
||||
.with_context(|| {
|
||||
format!("line did not start with prefix {line_prefix:?}: {line:?}")
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.log_err()
|
||||
@@ -11697,8 +11737,16 @@ impl Editor {
|
||||
.language_settings_at(Point::new(start_row, 0), cx)
|
||||
.preferred_line_length as usize
|
||||
});
|
||||
|
||||
let subsequent_lines_prefix = if let Some(rewrap_prefix_str) = &rewrap_prefix {
|
||||
format!("{}{}", indent_prefix, " ".repeat(rewrap_prefix_str.len()))
|
||||
} else {
|
||||
line_prefix.clone()
|
||||
};
|
||||
|
||||
let wrapped_text = wrap_with_prefix(
|
||||
line_prefix,
|
||||
subsequent_lines_prefix,
|
||||
lines_without_prefixes.join("\n"),
|
||||
wrap_column,
|
||||
tab_size,
|
||||
@@ -21200,18 +21248,22 @@ fn test_word_breaking_tokenizer() {
|
||||
}
|
||||
|
||||
fn wrap_with_prefix(
|
||||
line_prefix: String,
|
||||
first_line_prefix: String,
|
||||
subsequent_lines_prefix: String,
|
||||
unwrapped_text: String,
|
||||
wrap_column: usize,
|
||||
tab_size: NonZeroU32,
|
||||
preserve_existing_whitespace: bool,
|
||||
) -> String {
|
||||
let line_prefix_len = char_len_with_expanded_tabs(0, &line_prefix, tab_size);
|
||||
let first_line_prefix_len = char_len_with_expanded_tabs(0, &first_line_prefix, tab_size);
|
||||
let subsequent_lines_prefix_len =
|
||||
char_len_with_expanded_tabs(0, &subsequent_lines_prefix, tab_size);
|
||||
let mut wrapped_text = String::new();
|
||||
let mut current_line = line_prefix.clone();
|
||||
let mut current_line = first_line_prefix.clone();
|
||||
let mut is_first_line = true;
|
||||
|
||||
let tokenizer = WordBreakingTokenizer::new(&unwrapped_text);
|
||||
let mut current_line_len = line_prefix_len;
|
||||
let mut current_line_len = first_line_prefix_len;
|
||||
let mut in_whitespace = false;
|
||||
for token in tokenizer {
|
||||
let have_preceding_whitespace = in_whitespace;
|
||||
@@ -21221,13 +21273,19 @@ fn wrap_with_prefix(
|
||||
grapheme_len,
|
||||
} => {
|
||||
in_whitespace = false;
|
||||
let current_prefix_len = if is_first_line {
|
||||
first_line_prefix_len
|
||||
} else {
|
||||
subsequent_lines_prefix_len
|
||||
};
|
||||
if current_line_len + grapheme_len > wrap_column
|
||||
&& current_line_len != line_prefix_len
|
||||
&& current_line_len != current_prefix_len
|
||||
{
|
||||
wrapped_text.push_str(current_line.trim_end());
|
||||
wrapped_text.push('\n');
|
||||
current_line.truncate(line_prefix.len());
|
||||
current_line_len = line_prefix_len;
|
||||
is_first_line = false;
|
||||
current_line = subsequent_lines_prefix.clone();
|
||||
current_line_len = subsequent_lines_prefix_len;
|
||||
}
|
||||
current_line.push_str(token);
|
||||
current_line_len += grapheme_len;
|
||||
@@ -21244,32 +21302,46 @@ fn wrap_with_prefix(
|
||||
token = " ";
|
||||
grapheme_len = 1;
|
||||
}
|
||||
let current_prefix_len = if is_first_line {
|
||||
first_line_prefix_len
|
||||
} else {
|
||||
subsequent_lines_prefix_len
|
||||
};
|
||||
if current_line_len + grapheme_len > wrap_column {
|
||||
wrapped_text.push_str(current_line.trim_end());
|
||||
wrapped_text.push('\n');
|
||||
current_line.truncate(line_prefix.len());
|
||||
current_line_len = line_prefix_len;
|
||||
} else if current_line_len != line_prefix_len || preserve_existing_whitespace {
|
||||
is_first_line = false;
|
||||
current_line = subsequent_lines_prefix.clone();
|
||||
current_line_len = subsequent_lines_prefix_len;
|
||||
} else if current_line_len != current_prefix_len || preserve_existing_whitespace {
|
||||
current_line.push_str(token);
|
||||
current_line_len += grapheme_len;
|
||||
}
|
||||
}
|
||||
WordBreakToken::Newline => {
|
||||
in_whitespace = true;
|
||||
let current_prefix_len = if is_first_line {
|
||||
first_line_prefix_len
|
||||
} else {
|
||||
subsequent_lines_prefix_len
|
||||
};
|
||||
if preserve_existing_whitespace {
|
||||
wrapped_text.push_str(current_line.trim_end());
|
||||
wrapped_text.push('\n');
|
||||
current_line.truncate(line_prefix.len());
|
||||
current_line_len = line_prefix_len;
|
||||
is_first_line = false;
|
||||
current_line = subsequent_lines_prefix.clone();
|
||||
current_line_len = subsequent_lines_prefix_len;
|
||||
} else if have_preceding_whitespace {
|
||||
continue;
|
||||
} else if current_line_len + 1 > wrap_column && current_line_len != line_prefix_len
|
||||
} else if current_line_len + 1 > wrap_column
|
||||
&& current_line_len != current_prefix_len
|
||||
{
|
||||
wrapped_text.push_str(current_line.trim_end());
|
||||
wrapped_text.push('\n');
|
||||
current_line.truncate(line_prefix.len());
|
||||
current_line_len = line_prefix_len;
|
||||
} else if current_line_len != line_prefix_len {
|
||||
is_first_line = false;
|
||||
current_line = subsequent_lines_prefix.clone();
|
||||
current_line_len = subsequent_lines_prefix_len;
|
||||
} else if current_line_len != current_prefix_len {
|
||||
current_line.push(' ');
|
||||
current_line_len += 1;
|
||||
}
|
||||
@@ -21287,6 +21359,7 @@ fn wrap_with_prefix(
|
||||
fn test_wrap_with_prefix() {
|
||||
assert_eq!(
|
||||
wrap_with_prefix(
|
||||
"# ".to_string(),
|
||||
"# ".to_string(),
|
||||
"abcdefg".to_string(),
|
||||
4,
|
||||
@@ -21297,6 +21370,7 @@ fn test_wrap_with_prefix() {
|
||||
);
|
||||
assert_eq!(
|
||||
wrap_with_prefix(
|
||||
"".to_string(),
|
||||
"".to_string(),
|
||||
"\thello world".to_string(),
|
||||
8,
|
||||
@@ -21307,6 +21381,7 @@ fn test_wrap_with_prefix() {
|
||||
);
|
||||
assert_eq!(
|
||||
wrap_with_prefix(
|
||||
"// ".to_string(),
|
||||
"// ".to_string(),
|
||||
"xx \nyy zz aa bb cc".to_string(),
|
||||
12,
|
||||
@@ -21317,6 +21392,7 @@ fn test_wrap_with_prefix() {
|
||||
);
|
||||
assert_eq!(
|
||||
wrap_with_prefix(
|
||||
String::new(),
|
||||
String::new(),
|
||||
"这是什么 \n 钢笔".to_string(),
|
||||
3,
|
||||
|
||||
@@ -378,7 +378,6 @@ pub enum SnippetSortOrder {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct EditorSettingsContent {
|
||||
/// Whether the cursor blinks in the editor.
|
||||
///
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use gpui::{App, FontFeatures, FontWeight};
|
||||
use project::project_settings::{InlineBlameSettings, ProjectSettings};
|
||||
use settings::{EditableSettingControl, Settings};
|
||||
use theme::{FontFamilyCache, ThemeSettings};
|
||||
use theme::{FontFamilyCache, FontFamilyName, ThemeSettings};
|
||||
use ui::{
|
||||
CheckboxWithLabel, ContextMenu, DropdownMenu, NumericStepper, SettingsContainer, SettingsGroup,
|
||||
prelude::*,
|
||||
@@ -75,7 +75,7 @@ impl EditableSettingControl for BufferFontFamilyControl {
|
||||
value: Self::Value,
|
||||
_cx: &App,
|
||||
) {
|
||||
settings.buffer_font_family = Some(value.to_string());
|
||||
settings.buffer_font_family = Some(FontFamilyName(value.into()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ use language::{
|
||||
},
|
||||
tree_sitter_python,
|
||||
};
|
||||
use language_settings::{Formatter, FormatterList, IndentGuideSettings};
|
||||
use language_settings::{Formatter, IndentGuideSettings};
|
||||
use lsp::CompletionParams;
|
||||
use multi_buffer::{IndentGuide, PathKey};
|
||||
use parking_lot::Mutex;
|
||||
@@ -3567,7 +3567,7 @@ async fn test_indent_outdent_with_hard_tabs(cx: &mut TestAppContext) {
|
||||
#[gpui::test]
|
||||
fn test_indent_outdent_with_excerpts(cx: &mut TestAppContext) {
|
||||
init_test(cx, |settings| {
|
||||
settings.languages.extend([
|
||||
settings.languages.0.extend([
|
||||
(
|
||||
"TOML".into(),
|
||||
LanguageSettingsContent {
|
||||
@@ -5145,7 +5145,7 @@ fn test_transpose(cx: &mut TestAppContext) {
|
||||
#[gpui::test]
|
||||
async fn test_rewrap(cx: &mut TestAppContext) {
|
||||
init_test(cx, |settings| {
|
||||
settings.languages.extend([
|
||||
settings.languages.0.extend([
|
||||
(
|
||||
"Markdown".into(),
|
||||
LanguageSettingsContent {
|
||||
@@ -5210,6 +5210,10 @@ async fn test_rewrap(cx: &mut TestAppContext) {
|
||||
let markdown_language = Arc::new(Language::new(
|
||||
LanguageConfig {
|
||||
name: "Markdown".into(),
|
||||
rewrap_prefixes: vec![
|
||||
regex::Regex::new("\\d+\\.\\s+").unwrap(),
|
||||
regex::Regex::new("[-*+]\\s+").unwrap(),
|
||||
],
|
||||
..LanguageConfig::default()
|
||||
},
|
||||
None,
|
||||
@@ -5372,7 +5376,82 @@ async fn test_rewrap(cx: &mut TestAppContext) {
|
||||
A long long long line of markdown text
|
||||
to wrap.ˇ
|
||||
"},
|
||||
markdown_language,
|
||||
markdown_language.clone(),
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
// Test that rewrapping boundary works and preserves relative indent for Markdown documents
|
||||
assert_rewrap(
|
||||
indoc! {"
|
||||
«1. This is a numbered list item that is very long and needs to be wrapped properly.
|
||||
2. This is a numbered list item that is very long and needs to be wrapped properly.
|
||||
- This is an unordered list item that is also very long and should not merge with the numbered item.ˇ»
|
||||
"},
|
||||
indoc! {"
|
||||
«1. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
2. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
- This is an unordered list item that is
|
||||
also very long and should not merge
|
||||
with the numbered item.ˇ»
|
||||
"},
|
||||
markdown_language.clone(),
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
// Test that rewrapping add indents for rewrapping boundary if not exists already.
|
||||
assert_rewrap(
|
||||
indoc! {"
|
||||
«1. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
2. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
- This is an unordered list item that is
|
||||
also very long and should not merge with
|
||||
the numbered item.ˇ»
|
||||
"},
|
||||
indoc! {"
|
||||
«1. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
2. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
- This is an unordered list item that is
|
||||
also very long and should not merge
|
||||
with the numbered item.ˇ»
|
||||
"},
|
||||
markdown_language.clone(),
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
// Test that rewrapping maintain indents even when they already exists.
|
||||
assert_rewrap(
|
||||
indoc! {"
|
||||
«1. This is a numbered list
|
||||
item that is very long and needs to be wrapped properly.
|
||||
2. This is a numbered list
|
||||
item that is very long and needs to be wrapped properly.
|
||||
- This is an unordered list item that is also very long and
|
||||
should not merge with the numbered item.ˇ»
|
||||
"},
|
||||
indoc! {"
|
||||
«1. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
2. This is a numbered list item that is
|
||||
very long and needs to be wrapped
|
||||
properly.
|
||||
- This is an unordered list item that is
|
||||
also very long and should not merge
|
||||
with the numbered item.ˇ»
|
||||
"},
|
||||
markdown_language.clone(),
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
@@ -9326,7 +9405,7 @@ async fn test_document_format_during_save(cx: &mut TestAppContext) {
|
||||
|
||||
// Set rust language override and assert overridden tabsize is sent to language server
|
||||
update_test_language_settings(cx, |settings| {
|
||||
settings.languages.insert(
|
||||
settings.languages.0.insert(
|
||||
"Rust".into(),
|
||||
LanguageSettingsContent {
|
||||
tab_size: NonZeroU32::new(8),
|
||||
@@ -9890,7 +9969,7 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) {
|
||||
|
||||
// Set Rust language override and assert overridden tabsize is sent to language server
|
||||
update_test_language_settings(cx, |settings| {
|
||||
settings.languages.insert(
|
||||
settings.languages.0.insert(
|
||||
"Rust".into(),
|
||||
LanguageSettingsContent {
|
||||
tab_size: NonZeroU32::new(8),
|
||||
@@ -9933,9 +10012,9 @@ async fn test_range_format_during_save(cx: &mut TestAppContext) {
|
||||
#[gpui::test]
|
||||
async fn test_document_format_manual_trigger(cx: &mut TestAppContext) {
|
||||
init_test(cx, |settings| {
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(
|
||||
FormatterList(vec![Formatter::LanguageServer { name: None }].into()),
|
||||
))
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(vec![
|
||||
Formatter::LanguageServer { name: None },
|
||||
]))
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
@@ -10062,21 +10141,17 @@ async fn test_document_format_manual_trigger(cx: &mut TestAppContext) {
|
||||
async fn test_multiple_formatters(cx: &mut TestAppContext) {
|
||||
init_test(cx, |settings| {
|
||||
settings.defaults.remove_trailing_whitespace_on_save = Some(true);
|
||||
settings.defaults.formatter =
|
||||
Some(language_settings::SelectedFormatter::List(FormatterList(
|
||||
vec![
|
||||
Formatter::LanguageServer { name: None },
|
||||
Formatter::CodeActions(
|
||||
[
|
||||
("code-action-1".into(), true),
|
||||
("code-action-2".into(), true),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(vec![
|
||||
Formatter::LanguageServer { name: None },
|
||||
Formatter::CodeActions(
|
||||
[
|
||||
("code-action-1".into(), true),
|
||||
("code-action-2".into(), true),
|
||||
]
|
||||
.into(),
|
||||
)))
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
]))
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
@@ -10328,9 +10403,9 @@ async fn test_multiple_formatters(cx: &mut TestAppContext) {
|
||||
#[gpui::test]
|
||||
async fn test_organize_imports_manual_trigger(cx: &mut TestAppContext) {
|
||||
init_test(cx, |settings| {
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(
|
||||
FormatterList(vec![Formatter::LanguageServer { name: None }].into()),
|
||||
))
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(vec![
|
||||
Formatter::LanguageServer { name: None },
|
||||
]))
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
@@ -14905,7 +14980,7 @@ async fn test_language_server_restart_due_to_settings_change(cx: &mut TestAppCon
|
||||
.unwrap();
|
||||
let _fake_server = fake_servers.next().await.unwrap();
|
||||
update_test_language_settings(cx, |language_settings| {
|
||||
language_settings.languages.insert(
|
||||
language_settings.languages.0.insert(
|
||||
language_name.clone(),
|
||||
LanguageSettingsContent {
|
||||
tab_size: NonZeroU32::new(8),
|
||||
@@ -15803,9 +15878,9 @@ fn completion_menu_entries(menu: &CompletionsMenu) -> Vec<String> {
|
||||
#[gpui::test]
|
||||
async fn test_document_format_with_prettier(cx: &mut TestAppContext) {
|
||||
init_test(cx, |settings| {
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(
|
||||
FormatterList(vec![Formatter::Prettier].into()),
|
||||
))
|
||||
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(vec![
|
||||
Formatter::Prettier,
|
||||
]))
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
@@ -1307,7 +1307,7 @@ pub mod tests {
|
||||
use crate::scroll::ScrollAmount;
|
||||
use crate::{ExcerptRange, scroll::Autoscroll, test::editor_lsp_test_context::rust_lang};
|
||||
use futures::StreamExt;
|
||||
use gpui::{AppContext as _, Context, SemanticVersion, TestAppContext, WindowHandle};
|
||||
use gpui::{Context, SemanticVersion, TestAppContext, WindowHandle};
|
||||
use itertools::Itertools as _;
|
||||
use language::{Capability, FakeLspAdapter, language_settings::AllLanguageSettingsContent};
|
||||
use language::{Language, LanguageConfig, LanguageMatcher};
|
||||
|
||||
@@ -626,7 +626,7 @@ mod jsx_tag_autoclose_tests {
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use gpui::TestAppContext;
|
||||
use language::language_settings::JsxTagAutoCloseSettings;
|
||||
use languages::language;
|
||||
use multi_buffer::ExcerptRange;
|
||||
|
||||
3
crates/eval/src/examples/edit.rs
Normal file
3
crates/eval/src/examples/edit.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod basic;
|
||||
|
||||
pub use basic::*;
|
||||
104
crates/eval/src/examples/edit/basic.rs
Normal file
104
crates/eval/src/examples/edit/basic.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use std::{collections::HashSet, path::Path, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput};
|
||||
use async_trait::async_trait;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use collections::HashMap;
|
||||
|
||||
use crate::example::{
|
||||
Example, ExampleContext, ExampleMetadata, FileEditHunk, FileEdits, JudgeAssertion,
|
||||
LanguageServer,
|
||||
};
|
||||
|
||||
pub struct EditBasic;
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl Example for EditBasic {
|
||||
fn meta(&self) -> ExampleMetadata {
|
||||
ExampleMetadata {
|
||||
name: "edit_basic".to_string(),
|
||||
url: "https://github.com/zed-industries/zed.git".to_string(),
|
||||
revision: "58604fba86ebbffaa01f7c6834253e33bcd38c0f".to_string(),
|
||||
language_server: None,
|
||||
max_assertions: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
|
||||
cx.push_user_message(format!(
|
||||
r#"
|
||||
Read the `crates/git/src/blame.rs` file and delete `run_git_blame`. Just that
|
||||
one function, not its usages.
|
||||
|
||||
IMPORTANT: You are only allowed to use the `read_file` and `edit_file` tools!
|
||||
"#
|
||||
));
|
||||
|
||||
let response = cx.run_to_end().await?;
|
||||
// let expected_edits = HashMap::from_iter([(
|
||||
// Arc::from(Path::new("crates/git/src/blame.rs")),
|
||||
// FileEdits {
|
||||
// hunks: vec![
|
||||
// FileEditHunk {
|
||||
// base_text: " unique_shas.insert(entry.sha);\n".into(),
|
||||
// text: " unique_shas.insert(entry.git_sha);\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " pub sha: Oid,\n".into(),
|
||||
// text: " pub git_sha: Oid,\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " let sha = parts\n".into(),
|
||||
// text: " let git_sha = parts\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text:
|
||||
// " .ok_or_else(|| anyhow!(\"failed to parse sha\"))?;\n"
|
||||
// .into(),
|
||||
// text:
|
||||
// " .ok_or_else(|| anyhow!(\"failed to parse git_sha\"))?;\n"
|
||||
// .into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " sha,\n".into(),
|
||||
// text: " git_sha,\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " .get(&new_entry.sha)\n".into(),
|
||||
// text: " .get(&new_entry.git_sha)\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " let is_committed = !entry.sha.is_zero();\n"
|
||||
// .into(),
|
||||
// text: " let is_committed = !entry.git_sha.is_zero();\n"
|
||||
// .into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " index.insert(entry.sha, entries.len());\n"
|
||||
// .into(),
|
||||
// text: " index.insert(entry.git_sha, entries.len());\n"
|
||||
// .into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// FileEditHunk {
|
||||
// base_text: " if !entry.sha.is_zero() {\n".into(),
|
||||
// text: " if !entry.git_sha.is_zero() {\n".into(),
|
||||
// status: DiffHunkStatus::modified_none(),
|
||||
// },
|
||||
// ],
|
||||
// },
|
||||
// )]);
|
||||
// let actual_edits = cx.edits();
|
||||
// cx.assert_eq(&actual_edits, &expected_edits, "edits don't match")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown(
|
||||
| LanguageModelCompletionEvent::StartMessage { .. }
|
||||
| LanguageModelCompletionEvent::StatusUpdate { .. },
|
||||
) => {}
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
json_parse_error, ..
|
||||
}) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!(
|
||||
"**Error**: parse error in tool use JSON: {}\n\n",
|
||||
json_parse_error
|
||||
));
|
||||
}
|
||||
Err(error) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!("**Error**: {}\n\n", error));
|
||||
@@ -1132,6 +1141,17 @@ impl ThreadDialog {
|
||||
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
||||
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
||||
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
json_parse_error,
|
||||
..
|
||||
}) => {
|
||||
flush_text(&mut current_text, &mut content);
|
||||
content.push(MessageContent::Text(format!(
|
||||
"ERROR: parse error in tool use JSON: {}",
|
||||
json_parse_error
|
||||
)));
|
||||
}
|
||||
|
||||
Err(error) => {
|
||||
flush_text(&mut current_text, &mut content);
|
||||
content.push(MessageContent::Text(format!("ERROR: {}", error)));
|
||||
|
||||
@@ -8,7 +8,7 @@ use collections::{BTreeMap, HashSet};
|
||||
use extension::ExtensionHostProxy;
|
||||
use fs::{FakeFs, Fs, RealFs};
|
||||
use futures::{AsyncReadExt, StreamExt, io::BufReader};
|
||||
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
|
||||
use gpui::{SemanticVersion, TestAppContext};
|
||||
use http_client::{FakeHttpClient, Response};
|
||||
use language::{BinaryStatus, LanguageMatcher, LanguageRegistry};
|
||||
use lsp::LanguageServerName;
|
||||
|
||||
@@ -65,6 +65,7 @@ actions!(
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = git, deprecated_aliases = ["editor::RevertFile"])]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct RestoreFile {
|
||||
#[serde(default)]
|
||||
pub skip_prompt: bool,
|
||||
|
||||
@@ -12,7 +12,7 @@ license = "Apache-2.0"
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["http_client", "font-kit", "wayland", "x11"]
|
||||
default = ["http_client", "font-kit", "wayland", "x11", "windows-manifest"]
|
||||
test-support = [
|
||||
"leak-detection",
|
||||
"collections/test-support",
|
||||
@@ -69,7 +69,7 @@ x11 = [
|
||||
"open",
|
||||
"scap",
|
||||
]
|
||||
|
||||
windows-manifest = []
|
||||
|
||||
[lib]
|
||||
path = "src/gpui.rs"
|
||||
|
||||
@@ -17,7 +17,7 @@ fn main() {
|
||||
#[cfg(target_os = "macos")]
|
||||
macos::build();
|
||||
}
|
||||
#[cfg(target_os = "windows")]
|
||||
#[cfg(all(target_os = "windows", feature = "windows-manifest"))]
|
||||
Ok("windows") => {
|
||||
let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml");
|
||||
let rc_file = std::path::Path::new("resources/windows/gpui.rc");
|
||||
|
||||
@@ -125,9 +125,7 @@ pub trait Action: Any + Send {
|
||||
Self: Sized;
|
||||
|
||||
/// Optional JSON schema for the action's input data.
|
||||
fn action_json_schema(
|
||||
_: &mut schemars::r#gen::SchemaGenerator,
|
||||
) -> Option<schemars::schema::Schema>
|
||||
fn action_json_schema(_: &mut schemars::SchemaGenerator) -> Option<schemars::Schema>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
@@ -238,7 +236,7 @@ impl Default for ActionRegistry {
|
||||
|
||||
struct ActionData {
|
||||
pub build: ActionBuilder,
|
||||
pub json_schema: fn(&mut schemars::r#gen::SchemaGenerator) -> Option<schemars::schema::Schema>,
|
||||
pub json_schema: fn(&mut schemars::SchemaGenerator) -> Option<schemars::Schema>,
|
||||
}
|
||||
|
||||
/// This type must be public so that our macros can build it in other crates.
|
||||
@@ -253,7 +251,7 @@ pub struct MacroActionData {
|
||||
pub name: &'static str,
|
||||
pub type_id: TypeId,
|
||||
pub build: ActionBuilder,
|
||||
pub json_schema: fn(&mut schemars::r#gen::SchemaGenerator) -> Option<schemars::schema::Schema>,
|
||||
pub json_schema: fn(&mut schemars::SchemaGenerator) -> Option<schemars::Schema>,
|
||||
pub deprecated_aliases: &'static [&'static str],
|
||||
pub deprecation_message: Option<&'static str>,
|
||||
}
|
||||
@@ -357,8 +355,8 @@ impl ActionRegistry {
|
||||
|
||||
pub fn action_schemas(
|
||||
&self,
|
||||
generator: &mut schemars::r#gen::SchemaGenerator,
|
||||
) -> Vec<(&'static str, Option<schemars::schema::Schema>)> {
|
||||
generator: &mut schemars::SchemaGenerator,
|
||||
) -> Vec<(&'static str, Option<schemars::Schema>)> {
|
||||
// Use the order from all_names so that the resulting schema has sensible order.
|
||||
self.all_names
|
||||
.iter()
|
||||
|
||||
@@ -1069,7 +1069,7 @@ impl App {
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtains a reference to the executor, which can be used to spawn futures.
|
||||
/// Obtains a reference to the background executor, which can be used to spawn futures.
|
||||
pub fn background_executor(&self) -> &BackgroundExecutor {
|
||||
&self.background_executor
|
||||
}
|
||||
@@ -1334,6 +1334,11 @@ impl App {
|
||||
self.pending_effects.push_back(Effect::RefreshWindows);
|
||||
}
|
||||
|
||||
/// Get all key bindings in the app.
|
||||
pub fn key_bindings(&self) -> Rc<RefCell<Keymap>> {
|
||||
self.keymap.clone()
|
||||
}
|
||||
|
||||
/// Register a global listener for actions invoked via the keyboard.
|
||||
pub fn on_action<A: Action>(&mut self, listener: impl Fn(&A, &mut Self) + 'static) {
|
||||
self.global_action_listeners
|
||||
@@ -1388,8 +1393,8 @@ impl App {
|
||||
/// Get all non-internal actions that have been registered, along with their schemas.
|
||||
pub fn action_schemas(
|
||||
&self,
|
||||
generator: &mut schemars::r#gen::SchemaGenerator,
|
||||
) -> Vec<(&'static str, Option<schemars::schema::Schema>)> {
|
||||
generator: &mut schemars::SchemaGenerator,
|
||||
) -> Vec<(&'static str, Option<schemars::Schema>)> {
|
||||
self.actions.action_schemas(generator)
|
||||
}
|
||||
|
||||
|
||||
@@ -178,7 +178,14 @@ impl TestAppContext {
|
||||
&self.foreground_executor
|
||||
}
|
||||
|
||||
fn new<T: 'static>(&mut self, build_entity: impl FnOnce(&mut Context<T>) -> T) -> Entity<T> {
|
||||
/// Builds an entity that is owned by the application.
|
||||
///
|
||||
/// The given function will be invoked with a [`Context`] and must return an object representing the entity. An
|
||||
/// [`Entity`] handle will be returned, which can be used to access the entity in a context.
|
||||
pub fn new<T: 'static>(
|
||||
&mut self,
|
||||
build_entity: impl FnOnce(&mut Context<T>) -> T,
|
||||
) -> Entity<T> {
|
||||
let mut cx = self.app.borrow_mut();
|
||||
cx.new(build_entity)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use anyhow::{Context as _, bail};
|
||||
use schemars::{JsonSchema, SchemaGenerator, schema::Schema};
|
||||
use schemars::{JsonSchema, json_schema};
|
||||
use serde::{
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
de::{self, Visitor},
|
||||
};
|
||||
use std::borrow::Cow;
|
||||
use std::{
|
||||
fmt::{self, Display, Formatter},
|
||||
hash::{Hash, Hasher},
|
||||
@@ -99,22 +100,14 @@ impl Visitor<'_> for RgbaVisitor {
|
||||
}
|
||||
|
||||
impl JsonSchema for Rgba {
|
||||
fn schema_name() -> String {
|
||||
"Rgba".to_string()
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"Rgba".into()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
|
||||
use schemars::schema::{InstanceType, SchemaObject, StringValidation};
|
||||
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
string: Some(Box::new(StringValidation {
|
||||
pattern: Some(
|
||||
r"^#([0-9a-fA-F]{3}|[0-9a-fA-F]{4}|[0-9a-fA-F]{6}|[0-9a-fA-F]{8})$".to_string(),
|
||||
),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "string",
|
||||
"pattern": "^#([0-9a-fA-F]{3}|[0-9a-fA-F]{4}|[0-9a-fA-F]{6}|[0-9a-fA-F]{8})$"
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -629,11 +622,11 @@ impl From<Rgba> for Hsla {
|
||||
}
|
||||
|
||||
impl JsonSchema for Hsla {
|
||||
fn schema_name() -> String {
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
Rgba::schema_name()
|
||||
}
|
||||
|
||||
fn json_schema(generator: &mut SchemaGenerator) -> Schema {
|
||||
fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
Rgba::json_schema(generator)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -613,10 +613,10 @@ pub trait InteractiveElement: Sized {
|
||||
/// Track the focus state of the given focus handle on this element.
|
||||
/// If the focus handle is focused by the application, this element will
|
||||
/// apply its focused styles.
|
||||
fn track_focus(mut self, focus_handle: &FocusHandle) -> FocusableWrapper<Self> {
|
||||
fn track_focus(mut self, focus_handle: &FocusHandle) -> Self {
|
||||
self.interactivity().focusable = true;
|
||||
self.interactivity().tracked_focus_handle = Some(focus_handle.clone());
|
||||
FocusableWrapper { element: self }
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the keymap context for this element. This will be used to determine
|
||||
@@ -980,15 +980,35 @@ pub trait InteractiveElement: Sized {
|
||||
self.interactivity().block_mouse_except_scroll();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the given styles to be applied when this element, specifically, is focused.
|
||||
/// Requires that the element is focusable. Elements can be made focusable using [`InteractiveElement::track_focus`].
|
||||
fn focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.interactivity().focus_style = Some(Box::new(f(StyleRefinement::default())));
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the given styles to be applied when this element is inside another element that is focused.
|
||||
/// Requires that the element is focusable. Elements can be made focusable using [`InteractiveElement::track_focus`].
|
||||
fn in_focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.interactivity().in_focus_style = Some(Box::new(f(StyleRefinement::default())));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait for elements that want to use the standard GPUI interactivity features
|
||||
/// that require state.
|
||||
pub trait StatefulInteractiveElement: InteractiveElement {
|
||||
/// Set this element to focusable.
|
||||
fn focusable(mut self) -> FocusableWrapper<Self> {
|
||||
fn focusable(mut self) -> Self {
|
||||
self.interactivity().focusable = true;
|
||||
FocusableWrapper { element: self }
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the overflow x and y to scroll.
|
||||
@@ -1118,27 +1138,6 @@ pub trait StatefulInteractiveElement: InteractiveElement {
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait for providing focus related APIs to interactive elements
|
||||
pub trait FocusableElement: InteractiveElement {
|
||||
/// Set the given styles to be applied when this element, specifically, is focused.
|
||||
fn focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.interactivity().focus_style = Some(Box::new(f(StyleRefinement::default())));
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the given styles to be applied when this element is inside another element that is focused.
|
||||
fn in_focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.interactivity().in_focus_style = Some(Box::new(f(StyleRefinement::default())));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type MouseDownListener =
|
||||
Box<dyn Fn(&MouseDownEvent, DispatchPhase, &Hitbox, &mut Window, &mut App) + 'static>;
|
||||
pub(crate) type MouseUpListener =
|
||||
@@ -2777,126 +2776,6 @@ impl GroupHitboxes {
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around an element that can be focused.
|
||||
pub struct FocusableWrapper<E> {
|
||||
/// The element that is focusable
|
||||
pub element: E,
|
||||
}
|
||||
|
||||
impl<E: InteractiveElement> FocusableElement for FocusableWrapper<E> {}
|
||||
|
||||
impl<E> InteractiveElement for FocusableWrapper<E>
|
||||
where
|
||||
E: InteractiveElement,
|
||||
{
|
||||
fn interactivity(&mut self) -> &mut Interactivity {
|
||||
self.element.interactivity()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: StatefulInteractiveElement> StatefulInteractiveElement for FocusableWrapper<E> {}
|
||||
|
||||
impl<E> Styled for FocusableWrapper<E>
|
||||
where
|
||||
E: Styled,
|
||||
{
|
||||
fn style(&mut self) -> &mut StyleRefinement {
|
||||
self.element.style()
|
||||
}
|
||||
}
|
||||
|
||||
impl FocusableWrapper<Div> {
|
||||
/// Add a listener to be called when the children of this `Div` are prepainted.
|
||||
/// This allows you to store the [`Bounds`] of the children for later use.
|
||||
pub fn on_children_prepainted(
|
||||
mut self,
|
||||
listener: impl Fn(Vec<Bounds<Pixels>>, &mut Window, &mut App) + 'static,
|
||||
) -> Self {
|
||||
self.element = self.element.on_children_prepainted(listener);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Element for FocusableWrapper<E>
|
||||
where
|
||||
E: Element,
|
||||
{
|
||||
type RequestLayoutState = E::RequestLayoutState;
|
||||
type PrepaintState = E::PrepaintState;
|
||||
|
||||
fn id(&self) -> Option<ElementId> {
|
||||
self.element.id()
|
||||
}
|
||||
|
||||
fn source_location(&self) -> Option<&'static core::panic::Location<'static>> {
|
||||
self.element.source_location()
|
||||
}
|
||||
|
||||
fn request_layout(
|
||||
&mut self,
|
||||
id: Option<&GlobalElementId>,
|
||||
inspector_id: Option<&InspectorElementId>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> (LayoutId, Self::RequestLayoutState) {
|
||||
self.element.request_layout(id, inspector_id, window, cx)
|
||||
}
|
||||
|
||||
fn prepaint(
|
||||
&mut self,
|
||||
id: Option<&GlobalElementId>,
|
||||
inspector_id: Option<&InspectorElementId>,
|
||||
bounds: Bounds<Pixels>,
|
||||
state: &mut Self::RequestLayoutState,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> E::PrepaintState {
|
||||
self.element
|
||||
.prepaint(id, inspector_id, bounds, state, window, cx)
|
||||
}
|
||||
|
||||
fn paint(
|
||||
&mut self,
|
||||
id: Option<&GlobalElementId>,
|
||||
inspector_id: Option<&InspectorElementId>,
|
||||
bounds: Bounds<Pixels>,
|
||||
request_layout: &mut Self::RequestLayoutState,
|
||||
prepaint: &mut Self::PrepaintState,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.element.paint(
|
||||
id,
|
||||
inspector_id,
|
||||
bounds,
|
||||
request_layout,
|
||||
prepaint,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> IntoElement for FocusableWrapper<E>
|
||||
where
|
||||
E: IntoElement,
|
||||
{
|
||||
type Element = E::Element;
|
||||
|
||||
fn into_element(self) -> Self::Element {
|
||||
self.element.into_element()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> ParentElement for FocusableWrapper<E>
|
||||
where
|
||||
E: ParentElement,
|
||||
{
|
||||
fn extend(&mut self, elements: impl IntoIterator<Item = AnyElement>) {
|
||||
self.element.extend(elements)
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around an element that can store state, produced after assigning an ElementId.
|
||||
pub struct Stateful<E> {
|
||||
pub(crate) element: E,
|
||||
@@ -2927,8 +2806,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: FocusableElement> FocusableElement for Stateful<E> {}
|
||||
|
||||
impl<E> Element for Stateful<E>
|
||||
where
|
||||
E: Element,
|
||||
|
||||
@@ -25,7 +25,7 @@ use std::{
|
||||
use thiserror::Error;
|
||||
use util::ResultExt;
|
||||
|
||||
use super::{FocusableElement, Stateful, StatefulInteractiveElement};
|
||||
use super::{Stateful, StatefulInteractiveElement};
|
||||
|
||||
/// The delay before showing the loading state.
|
||||
pub const LOADING_DELAY: Duration = Duration::from_millis(200);
|
||||
@@ -509,8 +509,6 @@ impl IntoElement for Img {
|
||||
}
|
||||
}
|
||||
|
||||
impl FocusableElement for Img {}
|
||||
|
||||
impl StatefulInteractiveElement for Img {}
|
||||
|
||||
impl ImageSource {
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
use crate::{
|
||||
AnyElement, App, AvailableSpace, Bounds, ContentMask, DispatchPhase, Edges, Element, EntityId,
|
||||
FocusHandle, GlobalElementId, Hitbox, HitboxBehavior, InspectorElementId, IntoElement,
|
||||
Overflow, Pixels, Point, ScrollWheelEvent, Size, Style, StyleRefinement, Styled, Window, point,
|
||||
px, size,
|
||||
Overflow, Pixels, Point, ScrollDelta, ScrollWheelEvent, Size, Style, StyleRefinement, Styled,
|
||||
Window, point, px, size,
|
||||
};
|
||||
use collections::VecDeque;
|
||||
use refineable::Refineable as _;
|
||||
@@ -962,12 +962,15 @@ impl Element for List {
|
||||
let height = bounds.size.height;
|
||||
let scroll_top = prepaint.layout.scroll_top;
|
||||
let hitbox_id = prepaint.hitbox.id;
|
||||
let mut accumulated_scroll_delta = ScrollDelta::default();
|
||||
window.on_mouse_event(move |event: &ScrollWheelEvent, phase, window, cx| {
|
||||
if phase == DispatchPhase::Bubble && hitbox_id.should_handle_scroll(window) {
|
||||
accumulated_scroll_delta = accumulated_scroll_delta.coalesce(event.delta);
|
||||
let pixel_delta = accumulated_scroll_delta.pixel_delta(px(20.));
|
||||
list_state.0.borrow_mut().scroll(
|
||||
&scroll_top,
|
||||
height,
|
||||
event.delta.pixel_delta(px(20.)),
|
||||
pixel_delta,
|
||||
current_view,
|
||||
window,
|
||||
cx,
|
||||
|
||||
@@ -95,6 +95,13 @@ where
|
||||
.spawn(self.log_tracked_err(*location))
|
||||
.detach();
|
||||
}
|
||||
|
||||
/// Convert a Task<Result<T, E>> to a Task<()> that logs all errors.
|
||||
pub fn log_err_in_task(self, cx: &App) -> Task<Option<T>> {
|
||||
let location = core::panic::Location::caller();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { self.log_tracked_err(*location).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for Task<T> {
|
||||
|
||||
@@ -6,8 +6,9 @@ use anyhow::{Context as _, anyhow};
|
||||
use core::fmt::Debug;
|
||||
use derive_more::{Add, AddAssign, Div, DivAssign, Mul, Neg, Sub, SubAssign};
|
||||
use refineable::Refineable;
|
||||
use schemars::{JsonSchema, SchemaGenerator, schema::Schema};
|
||||
use schemars::{JsonSchema, json_schema};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
|
||||
use std::borrow::Cow;
|
||||
use std::{
|
||||
cmp::{self, PartialOrd},
|
||||
fmt::{self, Display},
|
||||
@@ -3229,20 +3230,14 @@ impl TryFrom<&'_ str> for AbsoluteLength {
|
||||
}
|
||||
|
||||
impl JsonSchema for AbsoluteLength {
|
||||
fn schema_name() -> String {
|
||||
"AbsoluteLength".to_string()
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"AbsoluteLength".into()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
|
||||
use schemars::schema::{InstanceType, SchemaObject, StringValidation};
|
||||
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
string: Some(Box::new(StringValidation {
|
||||
pattern: Some(r"^-?\d+(\.\d+)?(px|rem)$".to_string()),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "string",
|
||||
"pattern": r"^-?\d+(\.\d+)?(px|rem)$"
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3366,20 +3361,14 @@ impl TryFrom<&'_ str> for DefiniteLength {
|
||||
}
|
||||
|
||||
impl JsonSchema for DefiniteLength {
|
||||
fn schema_name() -> String {
|
||||
"DefiniteLength".to_string()
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"DefiniteLength".into()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
|
||||
use schemars::schema::{InstanceType, SchemaObject, StringValidation};
|
||||
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
string: Some(Box::new(StringValidation {
|
||||
pattern: Some(r"^-?\d+(\.\d+)?(px|rem|%)$".to_string()),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "string",
|
||||
"pattern": r"^-?\d+(\.\d+)?(px|rem|%)$"
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3480,20 +3469,14 @@ impl TryFrom<&'_ str> for Length {
|
||||
}
|
||||
|
||||
impl JsonSchema for Length {
|
||||
fn schema_name() -> String {
|
||||
"Length".to_string()
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"Length".into()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
|
||||
use schemars::schema::{InstanceType, SchemaObject, StringValidation};
|
||||
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
string: Some(Box::new(StringValidation {
|
||||
pattern: Some(r"^(auto|-?\d+(\.\d+)?(px|rem|%))$".to_string()),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "string",
|
||||
"pattern": r"^(auto|-?\d+(\.\d+)?(px|rem|%))$"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::rc::Rc;
|
||||
|
||||
use collections::HashMap;
|
||||
|
||||
use crate::{Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke};
|
||||
use crate::{Action, InvalidKeystrokeError, KeyBindingContextPredicate, Keystroke, SharedString};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
/// A keybinding and its associated metadata, from the keymap.
|
||||
@@ -11,6 +11,8 @@ pub struct KeyBinding {
|
||||
pub(crate) keystrokes: SmallVec<[Keystroke; 2]>,
|
||||
pub(crate) context_predicate: Option<Rc<KeyBindingContextPredicate>>,
|
||||
pub(crate) meta: Option<KeyBindingMetaIndex>,
|
||||
/// The json input string used when building the keybinding, if any
|
||||
pub(crate) action_input: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl Clone for KeyBinding {
|
||||
@@ -20,6 +22,7 @@ impl Clone for KeyBinding {
|
||||
keystrokes: self.keystrokes.clone(),
|
||||
context_predicate: self.context_predicate.clone(),
|
||||
meta: self.meta,
|
||||
action_input: self.action_input.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,7 +35,7 @@ impl KeyBinding {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self::load(keystrokes, Box::new(action), context_predicate, None).unwrap()
|
||||
Self::load(keystrokes, Box::new(action), context_predicate, None, None).unwrap()
|
||||
}
|
||||
|
||||
/// Load a keybinding from the given raw data.
|
||||
@@ -41,6 +44,7 @@ impl KeyBinding {
|
||||
action: Box<dyn Action>,
|
||||
context_predicate: Option<Rc<KeyBindingContextPredicate>>,
|
||||
key_equivalents: Option<&HashMap<char, char>>,
|
||||
action_input: Option<SharedString>,
|
||||
) -> std::result::Result<Self, InvalidKeystrokeError> {
|
||||
let mut keystrokes: SmallVec<[Keystroke; 2]> = keystrokes
|
||||
.split_whitespace()
|
||||
@@ -62,6 +66,7 @@ impl KeyBinding {
|
||||
action,
|
||||
context_predicate,
|
||||
meta: None,
|
||||
action_input,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -110,6 +115,11 @@ impl KeyBinding {
|
||||
pub fn meta(&self) -> Option<KeyBindingMetaIndex> {
|
||||
self.meta
|
||||
}
|
||||
|
||||
/// Get the action input associated with the action for this binding
|
||||
pub fn action_input(&self) -> Option<SharedString> {
|
||||
self.action_input.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for KeyBinding {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! application to avoid having to import each trait individually.
|
||||
|
||||
pub use crate::{
|
||||
AppContext as _, BorrowAppContext, Context, Element, FocusableElement, InteractiveElement,
|
||||
IntoElement, ParentElement, Refineable, Render, RenderOnce, StatefulInteractiveElement, Styled,
|
||||
StyledImage, VisualContext, util::FluentBuilder,
|
||||
AppContext as _, BorrowAppContext, Context, Element, InteractiveElement, IntoElement,
|
||||
ParentElement, Refineable, Render, RenderOnce, StatefulInteractiveElement, Styled, StyledImage,
|
||||
VisualContext, util::FluentBuilder,
|
||||
};
|
||||
|
||||
@@ -2,7 +2,10 @@ use derive_more::{Deref, DerefMut};
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Borrow, sync::Arc};
|
||||
use std::{
|
||||
borrow::{Borrow, Cow},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::arc_cow::ArcCow;
|
||||
|
||||
/// A shared string is an immutable string that can be cheaply cloned in GPUI
|
||||
@@ -23,12 +26,16 @@ impl SharedString {
|
||||
}
|
||||
|
||||
impl JsonSchema for SharedString {
|
||||
fn schema_name() -> String {
|
||||
fn inline_schema() -> bool {
|
||||
String::inline_schema()
|
||||
}
|
||||
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
String::schema_name()
|
||||
}
|
||||
|
||||
fn json_schema(r#gen: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
|
||||
String::json_schema(r#gen)
|
||||
fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
String::json_schema(generator)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
use schemars::schema::{InstanceType, SchemaObject};
|
||||
use schemars::{JsonSchema, json_schema};
|
||||
|
||||
/// The OpenType features that can be configured for a given font.
|
||||
#[derive(Default, Clone, Eq, PartialEq, Hash)]
|
||||
@@ -128,36 +129,23 @@ impl serde::Serialize for FontFeatures {
|
||||
}
|
||||
}
|
||||
|
||||
impl schemars::JsonSchema for FontFeatures {
|
||||
fn schema_name() -> String {
|
||||
impl JsonSchema for FontFeatures {
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"FontFeatures".into()
|
||||
}
|
||||
|
||||
fn json_schema(_: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
|
||||
let mut schema = SchemaObject::default();
|
||||
schema.instance_type = Some(schemars::schema::SingleOrVec::Single(Box::new(
|
||||
InstanceType::Object,
|
||||
)));
|
||||
{
|
||||
let mut property = SchemaObject {
|
||||
instance_type: Some(schemars::schema::SingleOrVec::Vec(vec![
|
||||
InstanceType::Boolean,
|
||||
InstanceType::Integer,
|
||||
])),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
{
|
||||
let mut number_constraints = property.number();
|
||||
number_constraints.multiple_of = Some(1.0);
|
||||
number_constraints.minimum = Some(0.0);
|
||||
}
|
||||
schema
|
||||
.object()
|
||||
.pattern_properties
|
||||
.insert("[0-9a-zA-Z]{4}$".into(), property.into());
|
||||
}
|
||||
schema.into()
|
||||
fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"[0-9a-zA-Z]{4}$": {
|
||||
"type": ["boolean", "integer"],
|
||||
"minimum": 0,
|
||||
"multipleOf": 1
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,9 +16,11 @@ fn test_action_macros() {
|
||||
|
||||
#[derive(PartialEq, Clone, Deserialize, JsonSchema, Action)]
|
||||
#[action(namespace = test_only)]
|
||||
struct AnotherSomeAction;
|
||||
#[serde(deny_unknown_fields)]
|
||||
struct AnotherAction;
|
||||
|
||||
#[derive(PartialEq, Clone, gpui::private::serde_derive::Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
struct RegisterableAction {}
|
||||
|
||||
register_action!(RegisterableAction);
|
||||
|
||||
@@ -159,8 +159,8 @@ pub(crate) fn derive_action(input: TokenStream) -> TokenStream {
|
||||
}
|
||||
|
||||
fn action_json_schema(
|
||||
_generator: &mut gpui::private::schemars::r#gen::SchemaGenerator,
|
||||
) -> Option<gpui::private::schemars::schema::Schema> {
|
||||
_generator: &mut gpui::private::schemars::SchemaGenerator,
|
||||
) -> Option<gpui::private::schemars::Schema> {
|
||||
#json_schema_fn_body
|
||||
}
|
||||
|
||||
|
||||
@@ -967,6 +967,7 @@ fn toggle_show_inline_completions_for_language(
|
||||
all_language_settings(None, cx).show_edit_predictions(Some(&language), cx);
|
||||
update_settings_file::<AllLanguageSettings>(fs, cx, move |file, _| {
|
||||
file.languages
|
||||
.0
|
||||
.entry(language.name())
|
||||
.or_default()
|
||||
.show_edit_predictions = Some(!show_edit_predictions);
|
||||
|
||||
@@ -39,6 +39,7 @@ globset.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
imara-diff.workspace = true
|
||||
inventory.workspace = true
|
||||
itertools.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
|
||||
@@ -2006,7 +2006,7 @@ fn test_autoindent_language_without_indents_query(cx: &mut App) {
|
||||
#[gpui::test]
|
||||
fn test_autoindent_with_injected_languages(cx: &mut App) {
|
||||
init_settings(cx, |settings| {
|
||||
settings.languages.extend([
|
||||
settings.languages.0.extend([
|
||||
(
|
||||
"HTML".into(),
|
||||
LanguageSettingsContent {
|
||||
|
||||
@@ -39,11 +39,7 @@ use lsp::{CodeActionKind, InitializeParams, LanguageServerBinary, LanguageServer
|
||||
pub use manifest::{ManifestDelegate, ManifestName, ManifestProvider, ManifestQuery};
|
||||
use parking_lot::Mutex;
|
||||
use regex::Regex;
|
||||
use schemars::{
|
||||
JsonSchema,
|
||||
r#gen::SchemaGenerator,
|
||||
schema::{InstanceType, Schema, SchemaObject},
|
||||
};
|
||||
use schemars::{JsonSchema, SchemaGenerator, json_schema};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
|
||||
use serde_json::Value;
|
||||
use settings::WorktreeId;
|
||||
@@ -694,7 +690,6 @@ pub struct LanguageConfig {
|
||||
pub matcher: LanguageMatcher,
|
||||
/// List of bracket types in a language.
|
||||
#[serde(default)]
|
||||
#[schemars(schema_with = "bracket_pair_config_json_schema")]
|
||||
pub brackets: BracketPairConfig,
|
||||
/// If set to true, auto indentation uses last non empty line to determine
|
||||
/// the indentation level for a new line.
|
||||
@@ -735,6 +730,13 @@ pub struct LanguageConfig {
|
||||
/// Starting and closing characters of a block comment.
|
||||
#[serde(default)]
|
||||
pub block_comment: Option<(Arc<str>, Arc<str>)>,
|
||||
/// A list of additional regex patterns that should be treated as prefixes
|
||||
/// for creating boundaries during rewrapping, ensuring content from one
|
||||
/// prefixed section doesn't merge with another (e.g., markdown list items).
|
||||
/// By default, Zed treats as paragraph and comment prefixes as boundaries.
|
||||
#[serde(default, deserialize_with = "deserialize_regex_vec")]
|
||||
#[schemars(schema_with = "regex_vec_json_schema")]
|
||||
pub rewrap_prefixes: Vec<Regex>,
|
||||
/// A list of language servers that are allowed to run on subranges of a given language.
|
||||
#[serde(default)]
|
||||
pub scope_opt_in_language_servers: Vec<LanguageServerName>,
|
||||
@@ -914,6 +916,7 @@ impl Default for LanguageConfig {
|
||||
autoclose_before: Default::default(),
|
||||
line_comments: Default::default(),
|
||||
block_comment: Default::default(),
|
||||
rewrap_prefixes: Default::default(),
|
||||
scope_opt_in_language_servers: Default::default(),
|
||||
overrides: Default::default(),
|
||||
word_characters: Default::default(),
|
||||
@@ -944,10 +947,9 @@ fn deserialize_regex<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Regex>, D
|
||||
}
|
||||
}
|
||||
|
||||
fn regex_json_schema(_: &mut SchemaGenerator) -> Schema {
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
..Default::default()
|
||||
fn regex_json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "string"
|
||||
})
|
||||
}
|
||||
|
||||
@@ -961,6 +963,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_regex_vec<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Regex>, D::Error> {
|
||||
let sources = Vec::<String>::deserialize(d)?;
|
||||
let mut regexes = Vec::new();
|
||||
for source in sources {
|
||||
regexes.push(regex::Regex::new(&source).map_err(de::Error::custom)?);
|
||||
}
|
||||
Ok(regexes)
|
||||
}
|
||||
|
||||
fn regex_vec_json_schema(_: &mut SchemaGenerator) -> schemars::Schema {
|
||||
json_schema!({
|
||||
"type": "array",
|
||||
"items": { "type": "string" }
|
||||
})
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub struct FakeLspAdapter {
|
||||
@@ -988,12 +1006,12 @@ pub struct FakeLspAdapter {
|
||||
/// This struct includes settings for defining which pairs of characters are considered brackets and
|
||||
/// also specifies any language-specific scopes where these pairs should be ignored for bracket matching purposes.
|
||||
#[derive(Clone, Debug, Default, JsonSchema)]
|
||||
#[schemars(with = "Vec::<BracketPairContent>")]
|
||||
pub struct BracketPairConfig {
|
||||
/// A list of character pairs that should be treated as brackets in the context of a given language.
|
||||
pub pairs: Vec<BracketPair>,
|
||||
/// A list of tree-sitter scopes for which a given bracket should not be active.
|
||||
/// N-th entry in `[Self::disabled_scopes_by_bracket_ix]` contains a list of disabled scopes for an n-th entry in `[Self::pairs]`
|
||||
#[serde(skip)]
|
||||
pub disabled_scopes_by_bracket_ix: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
@@ -1003,10 +1021,6 @@ impl BracketPairConfig {
|
||||
}
|
||||
}
|
||||
|
||||
fn bracket_pair_config_json_schema(r#gen: &mut SchemaGenerator) -> Schema {
|
||||
Option::<Vec<BracketPairContent>>::json_schema(r#gen)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
pub struct BracketPairContent {
|
||||
#[serde(flatten)]
|
||||
@@ -1841,6 +1855,14 @@ impl LanguageScope {
|
||||
.map(|e| (&e.0, &e.1))
|
||||
}
|
||||
|
||||
/// Returns additional regex patterns that act as prefix markers for creating
|
||||
/// boundaries during rewrapping.
|
||||
///
|
||||
/// By default, Zed treats as paragraph and comment prefixes as boundaries.
|
||||
pub fn rewrap_prefixes(&self) -> &[Regex] {
|
||||
&self.language.config.rewrap_prefixes
|
||||
}
|
||||
|
||||
/// Returns a list of language-specific word characters.
|
||||
///
|
||||
/// By default, Zed treats alphanumeric characters (and '_') as word characters for
|
||||
|
||||
@@ -1170,7 +1170,7 @@ impl LanguageRegistryState {
|
||||
if let Some(theme) = self.theme.as_ref() {
|
||||
language.set_theme(theme.syntax());
|
||||
}
|
||||
self.language_settings.languages.insert(
|
||||
self.language_settings.languages.0.insert(
|
||||
language.name(),
|
||||
LanguageSettingsContent {
|
||||
tab_size: language.config.tab_size,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
use crate::{File, Language, LanguageName, LanguageServerName};
|
||||
use anyhow::Result;
|
||||
use collections::{FxHashMap, HashMap, HashSet};
|
||||
use core::slice;
|
||||
use ec4rs::{
|
||||
Properties as EditorconfigProperties,
|
||||
property::{FinalNewline, IndentSize, IndentStyle, TabWidth, TrimTrailingWs},
|
||||
@@ -11,17 +10,15 @@ use ec4rs::{
|
||||
use globset::{Glob, GlobMatcher, GlobSet, GlobSetBuilder};
|
||||
use gpui::{App, Modifiers};
|
||||
use itertools::{Either, Itertools};
|
||||
use schemars::{
|
||||
JsonSchema,
|
||||
schema::{InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec},
|
||||
};
|
||||
use schemars::{JsonSchema, json_schema};
|
||||
use serde::{
|
||||
Deserialize, Deserializer, Serialize,
|
||||
de::{self, IntoDeserializer, MapAccess, SeqAccess, Visitor},
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
use settings::{
|
||||
Settings, SettingsLocation, SettingsSources, SettingsStore, add_references_to_properties,
|
||||
ParameterizedJsonSchema, Settings, SettingsLocation, SettingsSources, SettingsStore,
|
||||
replace_subschema,
|
||||
};
|
||||
use shellexpand;
|
||||
use std::{borrow::Cow, num::NonZeroU32, path::Path, sync::Arc};
|
||||
@@ -306,13 +303,42 @@ pub struct AllLanguageSettingsContent {
|
||||
pub defaults: LanguageSettingsContent,
|
||||
/// The settings for individual languages.
|
||||
#[serde(default)]
|
||||
pub languages: HashMap<LanguageName, LanguageSettingsContent>,
|
||||
pub languages: LanguageToSettingsMap,
|
||||
/// Settings for associating file extensions and filenames
|
||||
/// with languages.
|
||||
#[serde(default)]
|
||||
pub file_types: HashMap<Arc<str>, Vec<String>>,
|
||||
}
|
||||
|
||||
/// Map from language name to settings. Its `ParameterizedJsonSchema` allows only known language
|
||||
/// names in the keys.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct LanguageToSettingsMap(pub HashMap<LanguageName, LanguageSettingsContent>);
|
||||
|
||||
inventory::submit! {
|
||||
ParameterizedJsonSchema {
|
||||
add_and_get_ref: |generator, params, _cx| {
|
||||
let language_settings_content_ref = generator
|
||||
.subschema_for::<LanguageSettingsContent>()
|
||||
.to_value();
|
||||
let schema = json_schema!({
|
||||
"type": "object",
|
||||
"properties": params
|
||||
.language_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
(
|
||||
name.clone(),
|
||||
language_settings_content_ref.clone(),
|
||||
)
|
||||
})
|
||||
.collect::<serde_json::Map<_, _>>()
|
||||
});
|
||||
replace_subschema::<LanguageToSettingsMap>(generator, schema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls how completions are processed for this language.
|
||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -384,7 +410,6 @@ fn default_lsp_fetch_timeout_ms() -> u64 {
|
||||
|
||||
/// The settings for a particular language.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct LanguageSettingsContent {
|
||||
/// How many columns a tab should occupy.
|
||||
///
|
||||
@@ -648,45 +673,30 @@ pub enum FormatOnSave {
|
||||
On,
|
||||
/// Files should not be formatted on save.
|
||||
Off,
|
||||
List(FormatterList),
|
||||
List(Vec<Formatter>),
|
||||
}
|
||||
|
||||
impl JsonSchema for FormatOnSave {
|
||||
fn schema_name() -> String {
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"OnSaveFormatter".into()
|
||||
}
|
||||
|
||||
fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> Schema {
|
||||
let mut schema = SchemaObject::default();
|
||||
fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
let formatter_schema = Formatter::json_schema(generator);
|
||||
schema.instance_type = Some(
|
||||
vec![
|
||||
InstanceType::Object,
|
||||
InstanceType::String,
|
||||
InstanceType::Array,
|
||||
|
||||
json_schema!({
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": formatter_schema
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"enum": ["on", "off", "prettier", "language_server"]
|
||||
},
|
||||
formatter_schema
|
||||
]
|
||||
.into(),
|
||||
);
|
||||
|
||||
let valid_raw_values = SchemaObject {
|
||||
enum_values: Some(vec![
|
||||
Value::String("on".into()),
|
||||
Value::String("off".into()),
|
||||
Value::String("prettier".into()),
|
||||
Value::String("language_server".into()),
|
||||
]),
|
||||
..Default::default()
|
||||
};
|
||||
let mut nested_values = SchemaObject::default();
|
||||
|
||||
nested_values.array().items = Some(formatter_schema.clone().into());
|
||||
|
||||
schema.subschemas().any_of = Some(vec![
|
||||
nested_values.into(),
|
||||
valid_raw_values.into(),
|
||||
formatter_schema,
|
||||
]);
|
||||
schema.into()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -725,11 +735,11 @@ impl<'de> Deserialize<'de> for FormatOnSave {
|
||||
} else if v == "off" {
|
||||
Ok(Self::Value::Off)
|
||||
} else if v == "language_server" {
|
||||
Ok(Self::Value::List(FormatterList(
|
||||
Formatter::LanguageServer { name: None }.into(),
|
||||
)))
|
||||
Ok(Self::Value::List(vec![Formatter::LanguageServer {
|
||||
name: None,
|
||||
}]))
|
||||
} else {
|
||||
let ret: Result<FormatterList, _> =
|
||||
let ret: Result<Vec<Formatter>, _> =
|
||||
Deserialize::deserialize(v.into_deserializer());
|
||||
ret.map(Self::Value::List)
|
||||
}
|
||||
@@ -738,7 +748,7 @@ impl<'de> Deserialize<'de> for FormatOnSave {
|
||||
where
|
||||
A: MapAccess<'d>,
|
||||
{
|
||||
let ret: Result<FormatterList, _> =
|
||||
let ret: Result<Vec<Formatter>, _> =
|
||||
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map));
|
||||
ret.map(Self::Value::List)
|
||||
}
|
||||
@@ -746,7 +756,7 @@ impl<'de> Deserialize<'de> for FormatOnSave {
|
||||
where
|
||||
A: SeqAccess<'d>,
|
||||
{
|
||||
let ret: Result<FormatterList, _> =
|
||||
let ret: Result<Vec<Formatter>, _> =
|
||||
Deserialize::deserialize(de::value::SeqAccessDeserializer::new(map));
|
||||
ret.map(Self::Value::List)
|
||||
}
|
||||
@@ -783,45 +793,30 @@ pub enum SelectedFormatter {
|
||||
/// or falling back to formatting via language server.
|
||||
#[default]
|
||||
Auto,
|
||||
List(FormatterList),
|
||||
List(Vec<Formatter>),
|
||||
}
|
||||
|
||||
impl JsonSchema for SelectedFormatter {
|
||||
fn schema_name() -> String {
|
||||
fn schema_name() -> Cow<'static, str> {
|
||||
"Formatter".into()
|
||||
}
|
||||
|
||||
fn json_schema(generator: &mut schemars::r#gen::SchemaGenerator) -> Schema {
|
||||
let mut schema = SchemaObject::default();
|
||||
fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
let formatter_schema = Formatter::json_schema(generator);
|
||||
schema.instance_type = Some(
|
||||
vec![
|
||||
InstanceType::Object,
|
||||
InstanceType::String,
|
||||
InstanceType::Array,
|
||||
|
||||
json_schema!({
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": formatter_schema
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"enum": ["auto", "prettier", "language_server"]
|
||||
},
|
||||
formatter_schema
|
||||
]
|
||||
.into(),
|
||||
);
|
||||
|
||||
let valid_raw_values = SchemaObject {
|
||||
enum_values: Some(vec![
|
||||
Value::String("auto".into()),
|
||||
Value::String("prettier".into()),
|
||||
Value::String("language_server".into()),
|
||||
]),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut nested_values = SchemaObject::default();
|
||||
|
||||
nested_values.array().items = Some(formatter_schema.clone().into());
|
||||
|
||||
schema.subschemas().any_of = Some(vec![
|
||||
nested_values.into(),
|
||||
valid_raw_values.into(),
|
||||
formatter_schema,
|
||||
]);
|
||||
schema.into()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -836,6 +831,7 @@ impl Serialize for SelectedFormatter {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for SelectedFormatter {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
@@ -856,11 +852,11 @@ impl<'de> Deserialize<'de> for SelectedFormatter {
|
||||
if v == "auto" {
|
||||
Ok(Self::Value::Auto)
|
||||
} else if v == "language_server" {
|
||||
Ok(Self::Value::List(FormatterList(
|
||||
Formatter::LanguageServer { name: None }.into(),
|
||||
)))
|
||||
Ok(Self::Value::List(vec![Formatter::LanguageServer {
|
||||
name: None,
|
||||
}]))
|
||||
} else {
|
||||
let ret: Result<FormatterList, _> =
|
||||
let ret: Result<Vec<Formatter>, _> =
|
||||
Deserialize::deserialize(v.into_deserializer());
|
||||
ret.map(SelectedFormatter::List)
|
||||
}
|
||||
@@ -869,7 +865,7 @@ impl<'de> Deserialize<'de> for SelectedFormatter {
|
||||
where
|
||||
A: MapAccess<'d>,
|
||||
{
|
||||
let ret: Result<FormatterList, _> =
|
||||
let ret: Result<Vec<Formatter>, _> =
|
||||
Deserialize::deserialize(de::value::MapAccessDeserializer::new(map));
|
||||
ret.map(SelectedFormatter::List)
|
||||
}
|
||||
@@ -877,7 +873,7 @@ impl<'de> Deserialize<'de> for SelectedFormatter {
|
||||
where
|
||||
A: SeqAccess<'d>,
|
||||
{
|
||||
let ret: Result<FormatterList, _> =
|
||||
let ret: Result<Vec<Formatter>, _> =
|
||||
Deserialize::deserialize(de::value::SeqAccessDeserializer::new(map));
|
||||
ret.map(SelectedFormatter::List)
|
||||
}
|
||||
@@ -885,19 +881,6 @@ impl<'de> Deserialize<'de> for SelectedFormatter {
|
||||
deserializer.deserialize_any(FormatDeserializer)
|
||||
}
|
||||
}
|
||||
/// Controls which formatter should be used when formatting code.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case", transparent)]
|
||||
pub struct FormatterList(pub SingleOrVec<Formatter>);
|
||||
|
||||
impl AsRef<[Formatter]> for FormatterList {
|
||||
fn as_ref(&self) -> &[Formatter] {
|
||||
match &self.0 {
|
||||
SingleOrVec::Single(single) => slice::from_ref(single),
|
||||
SingleOrVec::Vec(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls which formatter should be used when formatting code. If there are multiple formatters, they are executed in the order of declaration.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
|
||||
@@ -1209,7 +1192,7 @@ impl settings::Settings for AllLanguageSettings {
|
||||
serde_json::from_value(serde_json::to_value(&default_value.defaults)?)?;
|
||||
|
||||
let mut languages = HashMap::default();
|
||||
for (language_name, settings) in &default_value.languages {
|
||||
for (language_name, settings) in &default_value.languages.0 {
|
||||
let mut language_settings = defaults.clone();
|
||||
merge_settings(&mut language_settings, settings);
|
||||
languages.insert(language_name.clone(), language_settings);
|
||||
@@ -1310,7 +1293,7 @@ impl settings::Settings for AllLanguageSettings {
|
||||
}
|
||||
|
||||
// A user's language-specific settings override default language-specific settings.
|
||||
for (language_name, user_language_settings) in &user_settings.languages {
|
||||
for (language_name, user_language_settings) in &user_settings.languages.0 {
|
||||
merge_settings(
|
||||
languages
|
||||
.entry(language_name.clone())
|
||||
@@ -1366,51 +1349,6 @@ impl settings::Settings for AllLanguageSettings {
|
||||
})
|
||||
}
|
||||
|
||||
fn json_schema(
|
||||
generator: &mut schemars::r#gen::SchemaGenerator,
|
||||
params: &settings::SettingsJsonSchemaParams,
|
||||
_: &App,
|
||||
) -> schemars::schema::RootSchema {
|
||||
let mut root_schema = generator.root_schema_for::<Self::FileContent>();
|
||||
|
||||
// Create a schema for a 'languages overrides' object, associating editor
|
||||
// settings with specific languages.
|
||||
assert!(
|
||||
root_schema
|
||||
.definitions
|
||||
.contains_key("LanguageSettingsContent")
|
||||
);
|
||||
|
||||
let languages_object_schema = SchemaObject {
|
||||
instance_type: Some(InstanceType::Object.into()),
|
||||
object: Some(Box::new(ObjectValidation {
|
||||
properties: params
|
||||
.language_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
(
|
||||
name.clone(),
|
||||
Schema::new_ref("#/definitions/LanguageSettingsContent".into()),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
root_schema
|
||||
.definitions
|
||||
.extend([("Languages".into(), languages_object_schema.into())]);
|
||||
|
||||
add_references_to_properties(
|
||||
&mut root_schema,
|
||||
&[("languages", "#/definitions/Languages")],
|
||||
);
|
||||
|
||||
root_schema
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
let d = &mut current.defaults;
|
||||
if let Some(size) = vscode
|
||||
@@ -1674,29 +1612,26 @@ mod tests {
|
||||
let settings: LanguageSettingsContent = serde_json::from_str(raw).unwrap();
|
||||
assert_eq!(
|
||||
settings.formatter,
|
||||
Some(SelectedFormatter::List(FormatterList(
|
||||
Formatter::LanguageServer { name: None }.into()
|
||||
)))
|
||||
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
|
||||
name: None
|
||||
}]))
|
||||
);
|
||||
let raw = "{\"formatter\": [{\"language_server\": {\"name\": null}}]}";
|
||||
let settings: LanguageSettingsContent = serde_json::from_str(raw).unwrap();
|
||||
assert_eq!(
|
||||
settings.formatter,
|
||||
Some(SelectedFormatter::List(FormatterList(
|
||||
vec![Formatter::LanguageServer { name: None }].into()
|
||||
)))
|
||||
Some(SelectedFormatter::List(vec![Formatter::LanguageServer {
|
||||
name: None
|
||||
}]))
|
||||
);
|
||||
let raw = "{\"formatter\": [{\"language_server\": {\"name\": null}}, \"prettier\"]}";
|
||||
let settings: LanguageSettingsContent = serde_json::from_str(raw).unwrap();
|
||||
assert_eq!(
|
||||
settings.formatter,
|
||||
Some(SelectedFormatter::List(FormatterList(
|
||||
vec![
|
||||
Formatter::LanguageServer { name: None },
|
||||
Formatter::Prettier
|
||||
]
|
||||
.into()
|
||||
)))
|
||||
Some(SelectedFormatter::List(vec![
|
||||
Formatter::LanguageServer { name: None },
|
||||
Formatter::Prettier
|
||||
]))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -165,10 +165,6 @@ impl LanguageModel for FakeLanguageModel {
|
||||
false
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
0 // No image support
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
"fake".to_string()
|
||||
}
|
||||
|
||||
@@ -9,17 +9,18 @@ mod telemetry;
|
||||
pub mod fake_provider;
|
||||
|
||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::Client;
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use http_client::http;
|
||||
use http_client::{StatusCode, http};
|
||||
use icons::IconName;
|
||||
use parking_lot::Mutex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use std::ops::{Add, Sub};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
@@ -34,11 +35,22 @@ pub use crate::request::*;
|
||||
pub use crate::role::*;
|
||||
pub use crate::telemetry::*;
|
||||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
|
||||
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
|
||||
LanguageModelProviderId::new("anthropic");
|
||||
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Anthropic");
|
||||
|
||||
/// If we get a rate limit error that doesn't tell us when we can retry,
|
||||
/// default to waiting this long before retrying.
|
||||
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
|
||||
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
|
||||
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Google AI");
|
||||
|
||||
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
|
||||
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("OpenAI");
|
||||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
|
||||
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Zed");
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
init_settings(cx);
|
||||
@@ -71,69 +83,194 @@ pub enum LanguageModelCompletionEvent {
|
||||
data: String,
|
||||
},
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolUseJsonParseError {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
role: Role,
|
||||
},
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("rate limit exceeded, retry after {retry_after:?}")]
|
||||
RateLimitExceeded { retry_after: Duration },
|
||||
#[error("received bad input JSON")]
|
||||
BadInputJson {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
#[error("language model provider's API is overloaded")]
|
||||
Overloaded,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
#[error("invalid request format to language model provider's API")]
|
||||
BadRequestFormat,
|
||||
#[error("authentication error with language model provider's API")]
|
||||
AuthenticationError,
|
||||
#[error("permission error with language model provider's API")]
|
||||
PermissionError,
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound,
|
||||
#[error("prompt too large for context window")]
|
||||
PromptTooLarge { tokens: Option<u64> },
|
||||
#[error("internal server error in language model provider's API")]
|
||||
ApiInternalServerError,
|
||||
#[error("I/O error reading response from language model provider's API: {0:?}")]
|
||||
ApiReadResponseError(io::Error),
|
||||
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
|
||||
HttpResponseError { status: u16, body: String },
|
||||
#[error("error serializing request to language model provider API: {0}")]
|
||||
SerializeRequest(serde_json::Error),
|
||||
#[error("error building request body to language model provider API: {0}")]
|
||||
BuildRequestBody(http::Error),
|
||||
#[error("error sending HTTP request to language model provider API: {0}")]
|
||||
HttpSend(anyhow::Error),
|
||||
#[error("error deserializing language model provider API response: {0}")]
|
||||
DeserializeResponse(serde_json::Error),
|
||||
#[error("unexpected language model provider API response format: {0}")]
|
||||
UnknownResponseFormat(String),
|
||||
#[error("missing {provider} API key")]
|
||||
NoApiKey { provider: LanguageModelProviderName },
|
||||
#[error("{provider}'s API rate limit exceeded")]
|
||||
RateLimitExceeded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API servers are overloaded right now")]
|
||||
ServerOverloaded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API server reported an internal server error: {message}")]
|
||||
ApiInternalServerError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||
HttpResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
},
|
||||
|
||||
// Client errors
|
||||
#[error("invalid request format to {provider}'s API: {message}")]
|
||||
BadRequestFormat {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("authentication error with {provider}'s API: {message}")]
|
||||
AuthenticationError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("permission error with {provider}'s API: {message}")]
|
||||
PermissionError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
||||
#[error("I/O error reading response from {provider}'s API")]
|
||||
ApiReadResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: io::Error,
|
||||
},
|
||||
#[error("error serializing request to {provider} API")]
|
||||
SerializeRequest {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
#[error("error building request body to {provider} API")]
|
||||
BuildRequestBody {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: http::Error,
|
||||
},
|
||||
#[error("error sending HTTP request to {provider} API")]
|
||||
HttpSend {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: anyhow::Error,
|
||||
},
|
||||
#[error("error deserializing {provider} API response")]
|
||||
DeserializeResponse {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
|
||||
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionError {
|
||||
pub fn from_cloud_failure(
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
code: String,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
if let Some(tokens) = parse_prompt_too_long(&message) {
|
||||
// TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
|
||||
// to be reported. This is a temporary workaround to handle this in the case where the
|
||||
// token limit has been exceeded.
|
||||
Self::PromptTooLarge {
|
||||
tokens: Some(tokens),
|
||||
}
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
||||
} else {
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_http_status(
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
match status_code {
|
||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
||||
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
||||
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&message),
|
||||
},
|
||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
fn from(error: AnthropicError) -> Self {
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error {
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend(error),
|
||||
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
|
||||
AnthropicError::HttpResponseError { status, body } => {
|
||||
Self::HttpResponseError { status, body }
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
AnthropicError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
AnthropicError::HttpResponseError {
|
||||
status_code,
|
||||
message,
|
||||
} => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: retry_after,
|
||||
},
|
||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,23 +278,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: anthropic::ApiError) -> Self {
|
||||
use anthropic::ApiErrorCode::*;
|
||||
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error.code() {
|
||||
Some(code) => match code {
|
||||
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
|
||||
AuthenticationError => LanguageModelCompletionError::AuthenticationError,
|
||||
PermissionError => LanguageModelCompletionError::PermissionError,
|
||||
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
|
||||
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
NotFoundError => Self::ApiEndpointNotFound { provider },
|
||||
RequestTooLarge => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&error.message),
|
||||
},
|
||||
RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
|
||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => LanguageModelCompletionError::ApiInternalServerError,
|
||||
OverloadedError => LanguageModelCompletionError::Overloaded,
|
||||
},
|
||||
None => LanguageModelCompletionError::Other(error.into()),
|
||||
None => Self::Other(error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -278,19 +431,21 @@ pub trait LanguageModel: Send + Sync {
|
||||
fn name(&self) -> LanguageModelName;
|
||||
fn provider_id(&self) -> LanguageModelProviderId;
|
||||
fn provider_name(&self) -> LanguageModelProviderName;
|
||||
fn upstream_provider_id(&self) -> LanguageModelProviderId {
|
||||
self.provider_id()
|
||||
}
|
||||
fn upstream_provider_name(&self) -> LanguageModelProviderName {
|
||||
self.provider_name()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String;
|
||||
|
||||
fn api_key(&self, _cx: &App) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Whether this model supports images. This is determined by whether self.max_image_size() is positive.
|
||||
fn supports_images(&self) -> bool {
|
||||
self.max_image_size() > 0
|
||||
}
|
||||
|
||||
/// The maximum image size the model accepts, in bytes. (Zero means images are unsupported.)
|
||||
fn max_image_size(&self) -> u64;
|
||||
/// Whether this model supports images
|
||||
fn supports_images(&self) -> bool;
|
||||
|
||||
/// Whether this model supports tools.
|
||||
fn supports_tools(&self) -> bool;
|
||||
@@ -346,7 +501,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
|
||||
if let Some(first_event) = events.next().await {
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id, .. }) => {
|
||||
message_id = Some(id.clone());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
@@ -370,6 +525,9 @@ pub trait LanguageModel: Send + Sync {
|
||||
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
..
|
||||
}) => None,
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
@@ -400,39 +558,6 @@ pub trait LanguageModel: Send + Sync {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LanguageModelKnownError {
|
||||
#[error("Context window limit exceeded ({tokens})")]
|
||||
ContextWindowLimitExceeded { tokens: u64 },
|
||||
#[error("Language model provider's API is currently overloaded")]
|
||||
Overloaded,
|
||||
#[error("Language model provider's API encountered an internal server error")]
|
||||
ApiInternalServerError,
|
||||
#[error("I/O error while reading response from language model provider's API: {0:?}")]
|
||||
ReadResponseError(io::Error),
|
||||
#[error("Error deserializing response from language model provider's API: {0:?}")]
|
||||
DeserializeResponse(serde_json::Error),
|
||||
#[error("Language model provider's API returned a response in an unknown format")]
|
||||
UnknownResponseFormat(String),
|
||||
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
|
||||
RateLimitExceeded { retry_after: Duration },
|
||||
}
|
||||
|
||||
impl LanguageModelKnownError {
|
||||
/// Attempts to map an HTTP response status code to a known error type.
|
||||
/// Returns None if the status code doesn't map to a specific known error.
|
||||
pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
|
||||
match status {
|
||||
429 => Some(Self::RateLimitExceeded {
|
||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
||||
}),
|
||||
503 => Some(Self::Overloaded),
|
||||
500..=599 => Some(Self::ApiInternalServerError),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||
fn name() -> String;
|
||||
fn description() -> String;
|
||||
@@ -478,7 +603,7 @@ pub trait LanguageModelProvider: 'static {
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum LanguageModelProviderTosView {
|
||||
/// When there are some past interactions in the Agent Panel.
|
||||
ThreadtEmptyState,
|
||||
ThreadEmptyState,
|
||||
/// When there are no past interactions in the Agent Panel.
|
||||
ThreadFreshStart,
|
||||
PromptEditorPopup,
|
||||
@@ -514,12 +639,30 @@ pub struct LanguageModelProviderId(pub SharedString);
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelProviderName(pub SharedString);
|
||||
|
||||
impl LanguageModelProviderId {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderName {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderName {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
|
||||
@@ -98,7 +98,7 @@ impl ConfiguredModel {
|
||||
}
|
||||
|
||||
pub fn is_provided_by_zed(&self) -> bool {
|
||||
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
|
||||
self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use gpui::{
|
||||
use image::codecs::png::PngEncoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
use zed_llm_client::{CompletionIntent, CompletionMode};
|
||||
pub use zed_llm_client::{CompletionIntent, CompletionMode};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct LanguageModelImage {
|
||||
@@ -344,6 +344,24 @@ impl From<&str> for MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolUse> for MessageContent {
|
||||
fn from(value: LanguageModelToolUse) -> Self {
|
||||
MessageContent::ToolUse(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelImage> for MessageContent {
|
||||
fn from(value: LanguageModelImage) -> Self {
|
||||
MessageContent::Image(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanguageModelToolResult> for MessageContent {
|
||||
fn from(value: LanguageModelToolResult) -> Self {
|
||||
MessageContent::ToolResult(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Hash)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
|
||||
@@ -36,6 +36,29 @@ impl Role {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anthropic::Role> for Role {
|
||||
fn from(role: anthropic::Role) -> Self {
|
||||
match role {
|
||||
anthropic::Role::User => Role::User,
|
||||
anthropic::Role::Assistant => Role::Assistant,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Role> for anthropic::Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(role: Role) -> Result<Self, Self::Error> {
|
||||
match role {
|
||||
Role::User => Ok(anthropic::Role::User),
|
||||
Role::Assistant => Ok(anthropic::Role::Assistant),
|
||||
Role::System => Err(anyhow::anyhow!(
|
||||
"System role is not supported in anthropic API"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::ANTHROPIC_PROVIDER_ID;
|
||||
use anthropic::ANTHROPIC_API_URL;
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use client::telemetry::Telemetry;
|
||||
@@ -8,8 +9,6 @@ use std::sync::Arc;
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use util::ResultExt;
|
||||
|
||||
pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
|
||||
|
||||
pub fn report_assistant_event(
|
||||
event: AssistantEventData,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
@@ -19,7 +18,7 @@ pub fn report_assistant_event(
|
||||
) {
|
||||
if let Some(telemetry) = telemetry.as_ref() {
|
||||
telemetry.report_assistant_event(event.clone());
|
||||
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
|
||||
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
|
||||
if let Some(api_key) = model_api_key {
|
||||
executor
|
||||
.spawn(async move {
|
||||
|
||||
@@ -20,8 +20,10 @@ aws-credential-types = { workspace = true, features = [
|
||||
] }
|
||||
aws_http_client.workspace = true
|
||||
bedrock.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
copilot.workspace = true
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
|
||||
@@ -33,8 +33,8 @@ use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
|
||||
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
|
||||
const PROVIDER_NAME: &str = "Anthropic";
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct AnthropicSettings {
|
||||
@@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -403,7 +403,11 @@ impl AnthropicModel {
|
||||
};
|
||||
|
||||
async move {
|
||||
let api_key = api_key.context("Missing Anthropic API Key")?;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
};
|
||||
let request =
|
||||
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
request.await.map_err(Into::into)
|
||||
@@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -437,13 +441,6 @@ impl LanguageModel for AnthropicModel {
|
||||
true
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
// Anthropic documentation: https://docs.anthropic.com/en/docs/build-with-claude/vision#faq
|
||||
// FAQ section: "Is there a limit to the image file size I can upload?"
|
||||
// "API: Maximum 5MB per image"
|
||||
5_242_880 // 5 MiB - Anthropic's stated maximum
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto
|
||||
@@ -813,12 +810,14 @@ impl AnthropicEventMapper {
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
},
|
||||
)),
|
||||
Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_use.id.into(),
|
||||
tool_name: tool_use.name.into(),
|
||||
raw_input: input_json.into(),
|
||||
json_parse_error: json_parse_err.to_string(),
|
||||
}),
|
||||
Err(json_parse_err) => {
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_use.id.into(),
|
||||
tool_name: tool_use.name.into(),
|
||||
raw_input: input_json.into(),
|
||||
json_parse_error: json_parse_err.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
vec![event_result]
|
||||
@@ -834,6 +833,7 @@ impl AnthropicEventMapper {
|
||||
))),
|
||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||
message_id: message.id,
|
||||
role: message.role.into(),
|
||||
}),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -46,14 +46,13 @@ use settings::{Settings, SettingsStore};
|
||||
use smol::lock::OnceCell;
|
||||
use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
|
||||
use theme::ThemeSettings;
|
||||
use tokio::runtime::Handle;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
|
||||
const PROVIDER_ID: &str = "amazon-bedrock";
|
||||
const PROVIDER_NAME: &str = "Amazon Bedrock";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
|
||||
|
||||
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
|
||||
pub struct BedrockCredentials {
|
||||
@@ -285,11 +284,11 @@ impl BedrockLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for BedrockLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -460,22 +459,22 @@ impl BedrockModel {
|
||||
&self,
|
||||
request: bedrock::Request,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<
|
||||
BoxFuture<'static, BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>>,
|
||||
> {
|
||||
let runtime_client = self
|
||||
.get_or_init_client(cx)
|
||||
let Ok(runtime_client) = self
|
||||
.get_or_init_client(&cx)
|
||||
.cloned()
|
||||
.context("Bedrock client not initialized")?;
|
||||
let owned_handle = self.handler.clone();
|
||||
.context("Bedrock client not initialized")
|
||||
else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
Ok(async move {
|
||||
let request = bedrock::stream_completion(runtime_client, request, owned_handle);
|
||||
request.await.unwrap_or_else(|e| {
|
||||
futures::stream::once(async move { Err(BedrockError::ClientError(e)) }).boxed()
|
||||
})
|
||||
match Tokio::spawn(cx, bedrock::stream_completion(runtime_client, request)) {
|
||||
Ok(res) => async { res.await.map_err(|err| anyhow!(err))? }.boxed(),
|
||||
Err(err) => futures::future::ready(Err(anyhow!(err))).boxed(),
|
||||
}
|
||||
.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,11 +488,11 @@ impl LanguageModel for BedrockModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -504,10 +503,6 @@ impl LanguageModel for BedrockModel {
|
||||
false
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
0 // Bedrock models don't currently support images in this implementation
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
|
||||
@@ -574,12 +569,10 @@ impl LanguageModel for BedrockModel {
|
||||
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
|
||||
};
|
||||
|
||||
let owned_handle = self.handler.clone();
|
||||
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.map_err(|err| anyhow!(err))?.await;
|
||||
let events = map_to_language_model_completion_events(response, owned_handle);
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
let events = map_to_language_model_completion_events(response);
|
||||
|
||||
if deny_tool_calls {
|
||||
Ok(deny_tool_use_events(events).boxed())
|
||||
@@ -883,7 +876,6 @@ pub fn get_bedrock_tokens(
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
|
||||
handle: Handle,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
@@ -896,198 +888,123 @@ pub fn map_to_language_model_completion_events(
|
||||
tool_uses_by_index: HashMap<i32, RawToolUse>,
|
||||
}
|
||||
|
||||
futures::stream::unfold(
|
||||
State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
},
|
||||
move |mut state: State| {
|
||||
let inner_handle = handle.clone();
|
||||
async move {
|
||||
inner_handle
|
||||
.spawn(async {
|
||||
while let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => match event {
|
||||
ConverseStreamOutput::ContentBlockDelta(cb_delta) => {
|
||||
match cb_delta.delta {
|
||||
Some(ContentBlockDelta::Text(text_out)) => {
|
||||
let completion_event =
|
||||
LanguageModelCompletionEvent::Text(text_out);
|
||||
return Some((Some(Ok(completion_event)), state));
|
||||
}
|
||||
let initial_state = State {
|
||||
events,
|
||||
tool_uses_by_index: HashMap::default(),
|
||||
};
|
||||
|
||||
Some(ContentBlockDelta::ToolUse(text_out)) => {
|
||||
if let Some(tool_use) = state
|
||||
.tool_uses_by_index
|
||||
.get_mut(&cb_delta.content_block_index)
|
||||
{
|
||||
tool_use.input_json.push_str(text_out.input());
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContentBlockDelta::ReasoningContent(thinking)) => {
|
||||
match thinking {
|
||||
ReasoningContentBlockDelta::RedactedContent(
|
||||
redacted,
|
||||
) => {
|
||||
let thinking_event =
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: String::from_utf8(
|
||||
redacted.into_inner(),
|
||||
)
|
||||
.unwrap_or("REDACTED".to_string()),
|
||||
signature: None,
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(thinking_event)),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ReasoningContentBlockDelta::Signature(
|
||||
signature,
|
||||
) => {
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".to_string(),
|
||||
signature: Some(signature)
|
||||
})),
|
||||
state,
|
||||
));
|
||||
}
|
||||
ReasoningContentBlockDelta::Text(thoughts) => {
|
||||
let thinking_event =
|
||||
LanguageModelCompletionEvent::Thinking {
|
||||
text: thoughts.to_string(),
|
||||
signature: None
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(thinking_event)),
|
||||
state,
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
ConverseStreamOutput::ContentBlockStart(cb_start) => {
|
||||
if let Some(ContentBlockStart::ToolUse(text_out)) =
|
||||
cb_start.start
|
||||
{
|
||||
let tool_use = RawToolUse {
|
||||
id: text_out.tool_use_id,
|
||||
name: text_out.name,
|
||||
input_json: String::new(),
|
||||
};
|
||||
|
||||
state
|
||||
.tool_uses_by_index
|
||||
.insert(cb_start.content_block_index, tool_use);
|
||||
}
|
||||
}
|
||||
ConverseStreamOutput::ContentBlockStop(cb_stop) => {
|
||||
if let Some(tool_use) = state
|
||||
.tool_uses_by_index
|
||||
.remove(&cb_stop.content_block_index)
|
||||
{
|
||||
let tool_use_event = LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input: if tool_use.input_json.is_empty() {
|
||||
Value::Null
|
||||
} else {
|
||||
serde_json::Value::from_str(
|
||||
&tool_use.input_json,
|
||||
)
|
||||
.map_err(|err| anyhow!(err))
|
||||
.unwrap()
|
||||
},
|
||||
};
|
||||
|
||||
return Some((
|
||||
Some(Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
tool_use_event,
|
||||
))),
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
ConverseStreamOutput::Metadata(cb_meta) => {
|
||||
if let Some(metadata) = cb_meta.usage {
|
||||
let completion_event =
|
||||
LanguageModelCompletionEvent::UsageUpdate(
|
||||
TokenUsage {
|
||||
input_tokens: metadata.input_tokens as u64,
|
||||
output_tokens: metadata.output_tokens as u64,
|
||||
cache_creation_input_tokens:
|
||||
metadata.cache_write_input_tokens.unwrap_or_default() as u64,
|
||||
cache_read_input_tokens:
|
||||
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
|
||||
},
|
||||
);
|
||||
return Some((Some(Ok(completion_event)), state));
|
||||
}
|
||||
}
|
||||
ConverseStreamOutput::MessageStop(message_stop) => {
|
||||
let reason = match message_stop.stop_reason {
|
||||
StopReason::ContentFiltered => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::EndTurn => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::GuardrailIntervened => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::MaxTokens => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::StopSequence => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
)
|
||||
}
|
||||
StopReason::ToolUse => {
|
||||
LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::ToolUse,
|
||||
)
|
||||
}
|
||||
_ => LanguageModelCompletionEvent::Stop(
|
||||
language_model::StopReason::EndTurn,
|
||||
),
|
||||
};
|
||||
return Some((Some(Ok(reason)), state));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
|
||||
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
|
||||
futures::stream::unfold(initial_state, |mut state| async move {
|
||||
match state.events.next().await {
|
||||
Some(event_result) => match event_result {
|
||||
Ok(event) => {
|
||||
let result = match event {
|
||||
ConverseStreamOutput::ContentBlockDelta(cb_delta) => match cb_delta.delta {
|
||||
Some(ContentBlockDelta::Text(text)) => {
|
||||
Some(Ok(LanguageModelCompletionEvent::Text(text)))
|
||||
}
|
||||
Some(ContentBlockDelta::ToolUse(tool_output)) => {
|
||||
if let Some(tool_use) = state
|
||||
.tool_uses_by_index
|
||||
.get_mut(&cb_delta.content_block_index)
|
||||
{
|
||||
tool_use.input_json.push_str(tool_output.input());
|
||||
}
|
||||
None
|
||||
}
|
||||
Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
|
||||
ReasoningContentBlockDelta::Text(thoughts) => {
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: thoughts.clone(),
|
||||
signature: None,
|
||||
}))
|
||||
}
|
||||
ReasoningContentBlockDelta::Signature(sig) => {
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: "".into(),
|
||||
signature: Some(sig),
|
||||
}))
|
||||
}
|
||||
ReasoningContentBlockDelta::RedactedContent(redacted) => {
|
||||
let content = String::from_utf8(redacted.into_inner())
|
||||
.unwrap_or("REDACTED".to_string());
|
||||
Some(Ok(LanguageModelCompletionEvent::Thinking {
|
||||
text: content,
|
||||
signature: None,
|
||||
}))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
},
|
||||
ConverseStreamOutput::ContentBlockStart(cb_start) => {
|
||||
if let Some(ContentBlockStart::ToolUse(tool_start)) = cb_start.start {
|
||||
state.tool_uses_by_index.insert(
|
||||
cb_start.content_block_index,
|
||||
RawToolUse {
|
||||
id: tool_start.tool_use_id,
|
||||
name: tool_start.name,
|
||||
input_json: String::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
None
|
||||
}
|
||||
None
|
||||
})
|
||||
.await
|
||||
.log_err()
|
||||
.flatten()
|
||||
}
|
||||
},
|
||||
)
|
||||
.filter_map(|event| async move { event })
|
||||
ConverseStreamOutput::ContentBlockStop(cb_stop) => state
|
||||
.tool_uses_by_index
|
||||
.remove(&cb_stop.content_block_index)
|
||||
.map(|tool_use| {
|
||||
let input = if tool_use.input_json.is_empty() {
|
||||
Value::Null
|
||||
} else {
|
||||
serde_json::Value::from_str(&tool_use.input_json)
|
||||
.unwrap_or(Value::Null)
|
||||
};
|
||||
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input,
|
||||
},
|
||||
))
|
||||
}),
|
||||
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
input_tokens: metadata.input_tokens as u64,
|
||||
output_tokens: metadata.output_tokens as u64,
|
||||
cache_creation_input_tokens: metadata
|
||||
.cache_write_input_tokens
|
||||
.unwrap_or_default()
|
||||
as u64,
|
||||
cache_read_input_tokens: metadata
|
||||
.cache_read_input_tokens
|
||||
.unwrap_or_default()
|
||||
as u64,
|
||||
}))
|
||||
}),
|
||||
ConverseStreamOutput::MessageStop(message_stop) => {
|
||||
let stop_reason = match message_stop.stop_reason {
|
||||
StopReason::ToolUse => language_model::StopReason::ToolUse,
|
||||
_ => language_model::StopReason::EndTurn,
|
||||
};
|
||||
Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason)))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Some((result, state))
|
||||
}
|
||||
Err(err) => Some((
|
||||
Some(Err(LanguageModelCompletionError::Other(anyhow!(err)))),
|
||||
state,
|
||||
)),
|
||||
},
|
||||
None => None,
|
||||
}
|
||||
})
|
||||
.filter_map(|result| async move { result })
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
|
||||
use anthropic::AnthropicModelMode;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
|
||||
use futures::{
|
||||
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
|
||||
@@ -8,25 +9,21 @@ use google_ai::GoogleModelMode;
|
||||
use gpui::{
|
||||
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
|
||||
};
|
||||
use http_client::http::{HeaderMap, HeaderValue};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
|
||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
|
||||
RefreshLlmTokenListener,
|
||||
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
|
||||
};
|
||||
use proto::Plan;
|
||||
use release_channel::AppVersion;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use settings::SettingsStore;
|
||||
use smol::Timer;
|
||||
use smol::io::{AsyncReadExt, BufReader};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
@@ -47,7 +44,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
|
||||
use crate::provider::google::{GoogleEventMapper, into_google};
|
||||
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
|
||||
|
||||
pub const PROVIDER_NAME: &str = "Zed";
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct ZedDotDevSettings {
|
||||
@@ -120,7 +118,7 @@ pub struct State {
|
||||
llm_api_token: LlmApiToken,
|
||||
user_store: Entity<UserStore>,
|
||||
status: client::Status,
|
||||
accept_terms: Option<Task<Result<()>>>,
|
||||
accept_terms_of_service_task: Option<Task<Result<()>>>,
|
||||
models: Vec<Arc<zed_llm_client::LanguageModel>>,
|
||||
default_model: Option<Arc<zed_llm_client::LanguageModel>>,
|
||||
default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
|
||||
@@ -144,7 +142,7 @@ impl State {
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
user_store,
|
||||
status,
|
||||
accept_terms: None,
|
||||
accept_terms_of_service_task: None,
|
||||
models: Vec::new(),
|
||||
default_model: None,
|
||||
default_fast_model: None,
|
||||
@@ -253,12 +251,12 @@ impl State {
|
||||
|
||||
fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
|
||||
let user_store = self.user_store.clone();
|
||||
self.accept_terms = Some(cx.spawn(async move |this, cx| {
|
||||
self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
|
||||
let _ = user_store
|
||||
.update(cx, |store, cx| store.accept_terms_of_service(cx))?
|
||||
.await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.accept_terms = None;
|
||||
this.accept_terms_of_service_task = None;
|
||||
cx.notify()
|
||||
})
|
||||
}));
|
||||
@@ -351,11 +349,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -397,7 +395,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &App) -> bool {
|
||||
!self.state.read(cx).is_signed_out()
|
||||
let state = self.state.read(cx);
|
||||
!state.is_signed_out() && state.has_accepted_terms_of_service(cx)
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
@@ -405,10 +404,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|_| ConfigurationView {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
.into()
|
||||
cx.new(|_| ConfigurationView::new(self.state.clone()))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn must_accept_terms(&self, cx: &App) -> bool {
|
||||
@@ -420,7 +417,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
view: LanguageModelProviderTosView,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
render_accept_terms(self.state.clone(), view, cx)
|
||||
let state = self.state.read(cx);
|
||||
if state.has_accepted_terms_of_service(cx) {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
|
||||
let state = self.state.clone();
|
||||
move |_window, cx| {
|
||||
state.update(cx, |state, cx| state.accept_terms_of_service(cx));
|
||||
}
|
||||
})
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
@@ -429,18 +438,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn render_accept_terms(
|
||||
state: Entity<State>,
|
||||
view_kind: LanguageModelProviderTosView,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
if state.read(cx).has_accepted_terms_of_service(cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let accept_terms_disabled = state.read(cx).accept_terms.is_some();
|
||||
|
||||
accept_terms_of_service_in_progress: bool,
|
||||
accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
|
||||
) -> impl IntoElement {
|
||||
let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
|
||||
let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);
|
||||
let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);
|
||||
|
||||
let terms_button = Button::new("terms_of_service", "Terms of Service")
|
||||
.style(ButtonStyle::Subtle)
|
||||
@@ -463,18 +466,11 @@ fn render_accept_terms(
|
||||
this.style(ButtonStyle::Tinted(TintColor::Warning))
|
||||
.label_size(LabelSize::Small)
|
||||
})
|
||||
.disabled(accept_terms_disabled)
|
||||
.on_click({
|
||||
let state = state.downgrade();
|
||||
move |_, _window, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.accept_terms_of_service(cx))
|
||||
.ok();
|
||||
}
|
||||
}),
|
||||
.disabled(accept_terms_of_service_in_progress)
|
||||
.on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
|
||||
);
|
||||
|
||||
let form = if thread_empty_state {
|
||||
if thread_empty_state {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.flex_wrap()
|
||||
@@ -512,12 +508,10 @@ fn render_accept_terms(
|
||||
LanguageModelProviderTosView::ThreadFreshStart => {
|
||||
button_container.w_full().justify_center()
|
||||
}
|
||||
LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
|
||||
LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
Some(form.into_any())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CloudLanguageModel {
|
||||
@@ -536,8 +530,6 @@ struct PerformLlmCompletionResponse {
|
||||
}
|
||||
|
||||
impl CloudLanguageModel {
|
||||
const MAX_RETRIES: usize = 3;
|
||||
|
||||
async fn perform_llm_completion(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
@@ -547,8 +539,7 @@ impl CloudLanguageModel {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut retries_remaining = Self::MAX_RETRIES;
|
||||
let mut retry_delay = Duration::from_secs(1);
|
||||
let mut refreshed_token = false;
|
||||
|
||||
loop {
|
||||
let request_builder = http_client::Request::builder()
|
||||
@@ -590,14 +581,20 @@ impl CloudLanguageModel {
|
||||
includes_status_messages,
|
||||
tool_use_limit_reached,
|
||||
});
|
||||
} else if response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
}
|
||||
|
||||
if !refreshed_token
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
retries_remaining -= 1;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else if status == StatusCode::FORBIDDEN
|
||||
refreshed_token = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if status == StatusCode::FORBIDDEN
|
||||
&& response
|
||||
.headers()
|
||||
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
|
||||
@@ -622,35 +619,18 @@ impl CloudLanguageModel {
|
||||
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("Forbidden");
|
||||
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
|
||||
// If we encounter an error in the 500 range, retry after a delay.
|
||||
// We've seen at least these in the wild from API providers:
|
||||
// * 500 Internal Server Error
|
||||
// * 502 Bad Gateway
|
||||
// * 529 Service Overloaded
|
||||
|
||||
if retries_remaining == 0 {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
anyhow::bail!(
|
||||
"cloud language model completion failed after {} retries with status {status}: {body}",
|
||||
Self::MAX_RETRIES
|
||||
);
|
||||
}
|
||||
|
||||
Timer::after(retry_delay).await;
|
||||
|
||||
retries_remaining -= 1;
|
||||
retry_delay *= 2; // If it fails again, wait longer.
|
||||
} else if status == StatusCode::PAYMENT_REQUIRED {
|
||||
return Err(anyhow!(PaymentRequiredError));
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(ApiError { status, body }));
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
let headers = response.headers().clone();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(ApiError {
|
||||
status,
|
||||
body,
|
||||
headers
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -660,6 +640,19 @@ impl CloudLanguageModel {
|
||||
struct ApiError {
|
||||
status: StatusCode,
|
||||
body: String,
|
||||
headers: HeaderMap<HeaderValue>,
|
||||
}
|
||||
|
||||
impl From<ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: ApiError) -> Self {
|
||||
let retry_after = None;
|
||||
LanguageModelCompletionError::from_http_status(
|
||||
PROVIDER_NAME,
|
||||
error.status,
|
||||
error.body,
|
||||
retry_after,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for CloudLanguageModel {
|
||||
@@ -672,11 +665,29 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn upstream_provider_id(&self) -> LanguageModelProviderId {
|
||||
use zed_llm_client::LanguageModelProvider::*;
|
||||
match self.model.provider {
|
||||
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
|
||||
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
|
||||
Google => language_model::GOOGLE_PROVIDER_ID,
|
||||
}
|
||||
}
|
||||
|
||||
fn upstream_provider_name(&self) -> LanguageModelProviderName {
|
||||
use zed_llm_client::LanguageModelProvider::*;
|
||||
match self.model.provider {
|
||||
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
|
||||
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
|
||||
Google => language_model::GOOGLE_PROVIDER_NAME,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -699,18 +710,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||
self.model.supports_max_mode
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
if self.model.supports_images {
|
||||
// Use a conservative limit that works across all providers
|
||||
// Anthropic has the smallest limit at 5 MiB
|
||||
// Anthropic documentation: https://docs.anthropic.com/en/docs/build-with-claude/vision#faq
|
||||
// "API: Maximum 5MB per image"
|
||||
5_242_880 // 5 MiB
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("zed.dev/{}", self.model.id)
|
||||
}
|
||||
@@ -788,6 +787,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
.body(serde_json::to_string(&request_body)?.into())?;
|
||||
let mut response = http_client.send(request).await?;
|
||||
let status = response.status();
|
||||
let headers = response.headers().clone();
|
||||
let mut response_body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
@@ -802,7 +802,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||
} else {
|
||||
Err(anyhow!(ApiError {
|
||||
status,
|
||||
body: response_body
|
||||
body: response_body,
|
||||
headers
|
||||
}))
|
||||
}
|
||||
}
|
||||
@@ -867,18 +868,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
)
|
||||
.await
|
||||
.map_err(|err| match err.downcast::<ApiError>() {
|
||||
Ok(api_err) => {
|
||||
if api_err.status == StatusCode::BAD_REQUEST {
|
||||
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
|
||||
return anyhow!(
|
||||
LanguageModelKnownError::ContextWindowLimitExceeded {
|
||||
tokens
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
anyhow!(api_err)
|
||||
}
|
||||
Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
|
||||
Err(err) => anyhow!(err),
|
||||
})?;
|
||||
|
||||
@@ -1007,7 +997,7 @@ where
|
||||
.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Err(error) => {
|
||||
vec![Err(LanguageModelCompletionError::Other(error))]
|
||||
vec![Err(LanguageModelCompletionError::from(error))]
|
||||
}
|
||||
Ok(CloudCompletionEvent::Status(event)) => {
|
||||
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
||||
@@ -1066,32 +1056,24 @@ fn response_lines<T: DeserializeOwned>(
|
||||
)
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: gpui::Entity<State>,
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
struct ZedAIConfiguration {
|
||||
is_connected: bool,
|
||||
plan: Option<proto::Plan>,
|
||||
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
|
||||
eligible_for_trial: bool,
|
||||
has_accepted_terms_of_service: bool,
|
||||
accept_terms_of_service_in_progress: bool,
|
||||
accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
|
||||
sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) {
|
||||
self.state.update(cx, |state, cx| {
|
||||
state.authenticate(cx).detach_and_log_err(cx);
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
impl RenderOnce for ZedAIConfiguration {
|
||||
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
||||
const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
|
||||
|
||||
let is_connected = !self.state.read(cx).is_signed_out();
|
||||
let user_store = self.state.read(cx).user_store.read(cx);
|
||||
let plan = user_store.current_plan();
|
||||
let subscription_period = user_store.subscription_period();
|
||||
let eligible_for_trial = user_store.trial_started_at().is_none();
|
||||
let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);
|
||||
|
||||
let is_pro = plan == Some(proto::Plan::ZedPro);
|
||||
let subscription_text = match (plan, subscription_period) {
|
||||
let is_pro = self.plan == Some(proto::Plan::ZedPro);
|
||||
let subscription_text = match (self.plan, self.subscription_period) {
|
||||
(Some(proto::Plan::ZedPro), Some(_)) => {
|
||||
"You have access to Zed's hosted LLMs through your Zed Pro subscription."
|
||||
}
|
||||
@@ -1102,7 +1084,7 @@ impl Render for ConfigurationView {
|
||||
"You have basic access to Zed's hosted LLMs through your Zed Free subscription."
|
||||
}
|
||||
_ => {
|
||||
if eligible_for_trial {
|
||||
if self.eligible_for_trial {
|
||||
"Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
|
||||
} else {
|
||||
"Subscribe for access to Zed's hosted LLMs."
|
||||
@@ -1113,7 +1095,7 @@ impl Render for ConfigurationView {
|
||||
h_flex().child(
|
||||
Button::new("manage_settings", "Manage Subscription")
|
||||
.style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
|
||||
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
|
||||
)
|
||||
} else {
|
||||
h_flex()
|
||||
@@ -1121,28 +1103,38 @@ impl Render for ConfigurationView {
|
||||
.child(
|
||||
Button::new("learn_more", "Learn more")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
|
||||
.on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
|
||||
)
|
||||
.child(
|
||||
Button::new("upgrade", "Upgrade")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.color(Color::Accent)
|
||||
.on_click(
|
||||
cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
|
||||
),
|
||||
Button::new(
|
||||
"upgrade",
|
||||
if self.plan.is_none() && self.eligible_for_trial {
|
||||
"Start Trial"
|
||||
} else {
|
||||
"Upgrade"
|
||||
},
|
||||
)
|
||||
.style(ButtonStyle::Subtle)
|
||||
.color(Color::Accent)
|
||||
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
|
||||
)
|
||||
};
|
||||
|
||||
if is_connected {
|
||||
if self.is_connected {
|
||||
v_flex()
|
||||
.gap_3()
|
||||
.w_full()
|
||||
.children(render_accept_terms(
|
||||
self.state.clone(),
|
||||
LanguageModelProviderTosView::Configuration,
|
||||
cx,
|
||||
))
|
||||
.when(has_accepted_terms, |this| {
|
||||
.when(!self.has_accepted_terms_of_service, |this| {
|
||||
this.child(render_accept_terms(
|
||||
LanguageModelProviderTosView::Configuration,
|
||||
self.accept_terms_of_service_in_progress,
|
||||
{
|
||||
let callback = self.accept_terms_of_service_callback.clone();
|
||||
move |window, cx| (callback)(window, cx)
|
||||
},
|
||||
))
|
||||
})
|
||||
.when(self.has_accepted_terms_of_service, |this| {
|
||||
this.child(subscription_text)
|
||||
.child(manage_subscription_buttons)
|
||||
})
|
||||
@@ -1155,8 +1147,126 @@ impl Render for ConfigurationView {
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
|
||||
.on_click({
|
||||
let callback = self.sign_in_callback.clone();
|
||||
move |_, window, cx| (callback)(window, cx)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: Entity<State>,
|
||||
accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
|
||||
sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn new(state: Entity<State>) -> Self {
|
||||
let accept_terms_of_service_callback = Arc::new({
|
||||
let state = state.clone();
|
||||
move |_window: &mut Window, cx: &mut App| {
|
||||
state.update(cx, |state, cx| {
|
||||
state.accept_terms_of_service(cx);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let sign_in_callback = Arc::new({
|
||||
let state = state.clone();
|
||||
move |_window: &mut Window, cx: &mut App| {
|
||||
state.update(cx, |state, cx| {
|
||||
state.authenticate(cx).detach_and_log_err(cx);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
state,
|
||||
accept_terms_of_service_callback,
|
||||
sign_in_callback,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let state = self.state.read(cx);
|
||||
let user_store = state.user_store.read(cx);
|
||||
|
||||
ZedAIConfiguration {
|
||||
is_connected: !state.is_signed_out(),
|
||||
plan: user_store.current_plan(),
|
||||
subscription_period: user_store.subscription_period(),
|
||||
eligible_for_trial: user_store.trial_started_at().is_none(),
|
||||
has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
|
||||
accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
|
||||
accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
|
||||
sign_in_callback: self.sign_in_callback.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for ZedAIConfiguration {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
fn configuration(
|
||||
is_connected: bool,
|
||||
plan: Option<proto::Plan>,
|
||||
eligible_for_trial: bool,
|
||||
has_accepted_terms_of_service: bool,
|
||||
) -> AnyElement {
|
||||
ZedAIConfiguration {
|
||||
is_connected,
|
||||
plan,
|
||||
subscription_period: plan
|
||||
.is_some()
|
||||
.then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
|
||||
eligible_for_trial,
|
||||
has_accepted_terms_of_service,
|
||||
accept_terms_of_service_in_progress: false,
|
||||
accept_terms_of_service_callback: Arc::new(|_, _| {}),
|
||||
sign_in_callback: Arc::new(|_, _| {}),
|
||||
}
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.p_4()
|
||||
.gap_4()
|
||||
.children(vec![
|
||||
single_example("Not connected", configuration(false, None, false, true)),
|
||||
single_example(
|
||||
"Accept Terms of Service",
|
||||
configuration(true, None, true, false),
|
||||
),
|
||||
single_example(
|
||||
"No Plan - Not eligible for trial",
|
||||
configuration(true, None, false, true),
|
||||
),
|
||||
single_example(
|
||||
"No Plan - Eligible for trial",
|
||||
configuration(true, None, true, true),
|
||||
),
|
||||
single_example(
|
||||
"Free Plan",
|
||||
configuration(true, Some(proto::Plan::Free), true, true),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro Trial Plan",
|
||||
configuration(true, Some(proto::Plan::ZedProTrial), true, true),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro Plan",
|
||||
configuration(true, Some(proto::Plan::ZedPro), true, true),
|
||||
),
|
||||
])
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
|
||||
use super::google::count_google_tokens;
|
||||
use super::open_ai::count_open_ai_tokens;
|
||||
|
||||
const PROVIDER_ID: &str = "copilot_chat";
|
||||
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
|
||||
const PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("GitHub Copilot Chat");
|
||||
|
||||
pub struct CopilotChatLanguageModelProvider {
|
||||
state: Entity<State>,
|
||||
@@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for CopilotChatLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -216,17 +217,6 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||
self.model.supports_vision()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
if self.model.supports_vision() {
|
||||
// OpenAI documentation: https://help.openai.com/en/articles/8983719-what-are-the-file-upload-size-restrictions
|
||||
// "For images, there's a limit of 20MB per image."
|
||||
// GitHub Copilot uses OpenAI models under the hood
|
||||
20_971_520 // 20 MB - GitHub Copilot uses OpenAI models
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
|
||||
match self.model.vendor() {
|
||||
ModelVendor::OpenAI | ModelVendor::Anthropic => {
|
||||
@@ -402,24 +392,24 @@ pub fn map_to_language_model_completion_events(
|
||||
serde_json::Value::from_str(&tool_call.arguments)
|
||||
};
|
||||
match arguments {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => Ok(
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
},
|
||||
),
|
||||
}
|
||||
},
|
||||
));
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: &str = "deepseek";
|
||||
const PROVIDER_NAME: &str = "DeepSeek";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
|
||||
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for DeepSeekLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -302,10 +302,6 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||
false
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
0 // DeepSeek models don't currently support images
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("deepseek/{}", self.model.id())
|
||||
}
|
||||
@@ -470,7 +466,7 @@ impl DeepSeekEventMapper {
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -480,7 +476,7 @@ impl DeepSeekEventMapper {
|
||||
event: deepseek::StreamResponse,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||
"Response contained no choices"
|
||||
)))];
|
||||
};
|
||||
@@ -542,8 +538,8 @@ impl DeepSeekEventMapper {
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_call.id.clone().into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
|
||||
@@ -37,8 +37,8 @@ use util::ResultExt;
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::ui::InstructionListItem;
|
||||
|
||||
const PROVIDER_ID: &str = "google";
|
||||
const PROVIDER_NAME: &str = "Google AI";
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct GoogleSettings {
|
||||
@@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for GoogleLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -349,17 +349,6 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
self.model.supports_images()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
if self.model.supports_images() {
|
||||
// Google Gemini documentation: https://ai.google.dev/gemini-api/docs/image-understanding
|
||||
// "Note: Inline image data limits your total request size (text prompts, system instructions, and inline bytes) to 20MB."
|
||||
// "For larger requests, upload image files using the File API."
|
||||
20_971_520 // 20 MB - Google Gemini's file API limit
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto
|
||||
@@ -434,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
);
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request
|
||||
.await
|
||||
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
|
||||
let response = request.await.map_err(LanguageModelCompletionError::from)?;
|
||||
Ok(GoogleEventMapper::new().map_stream(response))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
@@ -633,7 +620,7 @@ impl GoogleEventMapper {
|
||||
futures::stream::iter(match event {
|
||||
Some(Ok(event)) => self.map_event(event),
|
||||
Some(Err(error)) => {
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
|
||||
vec![Err(LanguageModelCompletionError::from(error))]
|
||||
}
|
||||
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
|
||||
})
|
||||
|
||||
@@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
|
||||
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
|
||||
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
|
||||
|
||||
const PROVIDER_ID: &str = "lmstudio";
|
||||
const PROVIDER_NAME: &str = "LM Studio";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct LmStudioSettings {
|
||||
@@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for LmStudioLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -410,18 +410,6 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||
self.model.supports_images
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
if self.model.supports_images {
|
||||
// LM Studio documentation: https://lmstudio.ai/docs/typescript/llm-prediction/image-input
|
||||
// While not explicitly stated, LM Studio uses a standard 20MB limit
|
||||
// matching OpenAI's documented limit: https://help.openai.com/en/articles/8983719-what-are-the-file-upload-size-restrictions
|
||||
// "For images, there's a limit of 20MB per image."
|
||||
20_971_520 // 20 MB - Default limit for local models
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("lmstudio/{}", self.model.id())
|
||||
}
|
||||
@@ -486,7 +474,7 @@ impl LmStudioEventMapper {
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -496,7 +484,7 @@ impl LmStudioEventMapper {
|
||||
event: lmstudio::ResponseStreamEvent,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.into_iter().next() else {
|
||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||
"Response contained no choices"
|
||||
)))];
|
||||
};
|
||||
@@ -565,7 +553,7 @@ impl LmStudioEventMapper {
|
||||
raw_input: tool_call.arguments,
|
||||
},
|
||||
)),
|
||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
||||
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
|
||||
@@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
};
|
||||
@@ -15,6 +14,7 @@ use language_model::{
|
||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use mistral::StreamResponse;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
@@ -29,8 +29,8 @@ use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: &str = "mistral";
|
||||
const PROVIDER_NAME: &str = "Mistral";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct MistralSettings {
|
||||
@@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for MistralLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -317,18 +317,6 @@ impl LanguageModel for MistralLanguageModel {
|
||||
self.model.supports_images()
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
if self.model.supports_images() {
|
||||
// Mistral documentation: https://www.infoq.com/news/2025/03/mistral-ai-ocr-api/
|
||||
// "The API is currently limited to files that do not exceed 50MB in size or 1,000 pages"
|
||||
// Also confirmed in https://github.com/everaldo/mcp-mistral-ocr/blob/master/README.md
|
||||
// "Maximum file size: 50MB (enforced by Mistral API)"
|
||||
52_428_800 // 50 MB - Mistral's OCR API limit
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("mistral/{}", self.model.id())
|
||||
}
|
||||
@@ -591,13 +579,13 @@ impl MistralEventMapper {
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
|
||||
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -607,7 +595,7 @@ impl MistralEventMapper {
|
||||
event: mistral::StreamResponse,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
return vec![Err(LanguageModelCompletionError::from(anyhow!(
|
||||
"Response contained no choices"
|
||||
)))];
|
||||
};
|
||||
@@ -672,7 +660,7 @@ impl MistralEventMapper {
|
||||
|
||||
for (_, tool_call) in self.tool_calls_by_index.drain() {
|
||||
if tool_call.id.is_empty() || tool_call.name.is_empty() {
|
||||
results.push(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
results.push(Err(LanguageModelCompletionError::from(anyhow!(
|
||||
"Received incomplete tool call: missing id or name"
|
||||
))));
|
||||
continue;
|
||||
@@ -688,12 +676,14 @@ impl MistralEventMapper {
|
||||
raw_input: tool_call.arguments,
|
||||
},
|
||||
))),
|
||||
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})),
|
||||
Err(error) => {
|
||||
results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
|
||||
const OLLAMA_SITE: &str = "https://ollama.com/";
|
||||
|
||||
const PROVIDER_ID: &str = "ollama";
|
||||
const PROVIDER_NAME: &str = "Ollama";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct OllamaSettings {
|
||||
@@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -365,16 +365,6 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
self.model.supports_vision.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn max_image_size(&self) -> u64 {
|
||||
if self.model.supports_vision.unwrap_or(false) {
|
||||
// Ollama documentation: https://github.com/ollama/ollama/releases/tag/v0.1.15
|
||||
// "Images up to 100MB in size are supported."
|
||||
104_857_600 // 100 MB - Ollama's documented API limit
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto => false,
|
||||
@@ -463,7 +453,7 @@ fn map_to_language_model_completion_events(
|
||||
let delta = match response {
|
||||
Ok(delta) => delta,
|
||||
Err(e) => {
|
||||
let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
|
||||
let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
|
||||
return Some((vec![event], state));
|
||||
}
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user