Compare commits

...

7 Commits

Author SHA1 Message Date
Piotr Osiewicz
02cc7f01a2 chore: Fix warning from upcoming Rust 1.81 release 2024-09-04 21:14:22 +02:00
Piotr Osiewicz
8d4bdd6dc6 lsp: Fill in version for SnippetEdit from drive (#17360)
Related to #16680 

Release Notes:

- N/A
2024-09-04 19:31:32 +02:00
Marshall Bowers
30b2133336 language_model: Add tool results to message content (#17363)
This PR updates the message content for an LLM request to allow it to
contain tool results.

Release Notes:

- N/A
2024-09-04 13:29:01 -04:00
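As a rough sketch (not part of this PR's diff; import paths are approximate), the new types let a caller hand a tool's output back to the model as message content:

```rust
use language_model::{LanguageModelToolResult, MessageContent};

/// Sketch: wrap a tool's output so it can be sent back on the next request.
fn tool_result_content(tool_use_id: String, output: String) -> MessageContent {
    MessageContent::ToolResult(LanguageModelToolResult {
        // Must echo the `id` from the model's earlier `ToolUse` event.
        tool_use_id,
        is_error: false,
        content: output,
    })
}
```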
David Soria Parra
74907cb3e6 context_servers: Pass env variables from settings (#17356)
Users can now pass an `env` dictionary of string-to-string mappings to a
context server binary.

Release Notes:

- context_servers: Settings now allow the configuration of env variables
that are passed to the server process
2024-09-04 12:34:43 -04:00
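In Rust terms (a minimal sketch, assuming `ServerConfig` from this diff is in scope; all values are hypothetical), the new field is just an optional string-to-string map:

```rust
use collections::HashMap; // the diff's import; std::collections::HashMap also works

// Sketch: constructing a `ServerConfig` with env vars for the server process.
fn example_config() -> ServerConfig {
    ServerConfig {
        id: "my-server".into(),
        executable: "/path/to/context-server".into(),
        args: vec!["--stdio".into()],
        env: Some(HashMap::from_iter([(
            "API_KEY".to_string(),
            "...".to_string(), // hypothetical value
        )])),
    }
}
```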
Marshall Bowers
f38956943b assistant: Propagate LLM stop reason upwards (#17358)
This PR propagates the `stop_reason` from Anthropic up to the Assistant so
that we can take action based on it.

The `extract_content_from_events` function was moved from the `anthropic`
crate to the `anthropic` provider module in `language_model`, where it is
more useful: it can name the `LanguageModelCompletionEvent` type directly,
whereas otherwise we'd need an additional layer of plumbing.

Release Notes:

- N/A
2024-09-04 12:31:10 -04:00
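A sketch of what this enables downstream (not from the diff; in this PR the Assistant's own `Stop` match arms are still empty, see the `impl Context` hunk below):

```rust
use language_model::{LanguageModelCompletionEvent, StopReason};

// Sketch: a consumer of the completion stream can now branch on why
// generation stopped instead of only seeing text and tool-use events.
fn handle(event: LanguageModelCompletionEvent) {
    match event {
        LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {
            // run the requested tool, then send back a tool result
        }
        LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
            // the reply was cut off; warn the user or continue generation
        }
        LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {
            // normal end of the assistant's turn
        }
        LanguageModelCompletionEvent::Text(text) => print!("{text}"),
        LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
    }
}
```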
Mathias
7c8f62e943 Add hard_tabs: false in project settings (#17357)
# Problem

I have a custom system-wide rustfmt configuration and use tabs instead of
spaces, so when I contribute to Zed I get lots of formatting errors.

# Proposition

- ~~Add rustfmt.toml (to specify that you are using the default rustfmt
configuration, see https://github.com/rust-lang/cargo/issues/14442)~~
- Add `hard_tabs: false` to `.zed/settings.json` for people using tabs
over spaces.

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
2024-09-04 12:30:28 -04:00
Joseph T Lyons
bc39ca06a7 v0.153.x dev 2024-09-04 10:30:40 -04:00
13 changed files with 243 additions and 203 deletions

View File

@@ -38,6 +38,7 @@
}
}
},
"hard_tabs": false,
"formatter": "auto",
"remove_trailing_whitespace_on_save": true,
"ensure_final_newline_on_save": true

Cargo.lock generated
View File

@@ -243,7 +243,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"collections",
"futures 0.3.30",
"http_client",
"isahc",
@@ -14212,7 +14211,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.152.0"
version = "0.153.0"
dependencies = [
"activity_indicator",
"anyhow",

View File

@@ -18,7 +18,6 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
chrono.workspace = true
-collections.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true

View File

@@ -5,7 +5,6 @@ use std::{pin::Pin, str::FromStr};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
-use collections::HashMap;
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
@@ -13,7 +12,7 @@ use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
-use util::{maybe, ResultExt as _};
+use util::ResultExt as _;
pub use supported_countries::*;
@@ -332,94 +331,6 @@ pub async fn stream_completion_with_rate_limit_info(
}
}
-pub fn extract_content_from_events(
-events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
-struct RawToolUse {
-id: String,
-name: String,
-input_json: String,
-}
-struct State {
-events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-tool_uses_by_index: HashMap<usize, RawToolUse>,
-}
-futures::stream::unfold(
-State {
-events,
-tool_uses_by_index: HashMap::default(),
-},
-|mut state| async move {
-while let Some(event) = state.events.next().await {
-match event {
-Ok(event) => match event {
-Event::ContentBlockStart {
-index,
-content_block,
-} => match content_block {
-ResponseContent::Text { text } => {
-return Some((Some(Ok(ResponseContent::Text { text })), state));
-}
-ResponseContent::ToolUse { id, name, .. } => {
-state.tool_uses_by_index.insert(
-index,
-RawToolUse {
-id,
-name,
-input_json: String::new(),
-},
-);
-return Some((None, state));
-}
-},
-Event::ContentBlockDelta { index, delta } => match delta {
-ContentDelta::TextDelta { text } => {
-return Some((Some(Ok(ResponseContent::Text { text })), state));
-}
-ContentDelta::InputJsonDelta { partial_json } => {
-if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
-tool_use.input_json.push_str(&partial_json);
-return Some((None, state));
-}
-}
-},
-Event::ContentBlockStop { index } => {
-if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
-return Some((
-Some(maybe!({
-Ok(ResponseContent::ToolUse {
-id: tool_use.id,
-name: tool_use.name,
-input: serde_json::Value::from_str(
-&tool_use.input_json,
-)
-.map_err(|err| anyhow!(err))?,
-})
-})),
-state,
-));
-}
-}
-Event::Error { error } => {
-return Some((Some(Err(AnthropicError::ApiError(error))), state));
-}
-_ => {}
-},
-Err(err) => {
-return Some((Some(Err(err)), state));
-}
-}
-}
-None
-},
-)
-.filter_map(|event| async move { event })
-}
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
@@ -512,6 +423,14 @@ pub enum RequestContent {
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
is_error: bool,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
}
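// Note: given the `tool_result` rename above (and assuming the enum's usual
// "type" tag, matching the other variants), this serializes as e.g.
// {"type":"tool_result","tool_use_id":"...","is_error":false,"content":"..."}.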
#[derive(Debug, Serialize, Deserialize)]

View File

@@ -1999,6 +1999,11 @@ impl Context {
});
match event {
+LanguageModelCompletionEvent::Stop(reason) => match reason {
+language_model::StopReason::ToolUse => {}
+language_model::StopReason::EndTurn => {}
+language_model::StopReason::MaxTokens => {}
+},
LanguageModelCompletionEvent::Text(chunk) => {
buffer.edit(
[(

View File

@@ -39,6 +39,7 @@ pub struct ServerConfig {
pub id: String,
pub executable: String,
pub args: Vec<String>,
+pub env: Option<HashMap<String, String>>,
}
impl Settings for ContextServerSettings {
@@ -70,13 +71,13 @@ impl ContextServer {
}
async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
log::info!("starting context server {}", self.config.id);
log::info!("starting context server {}", self.config.id,);
let client = Client::new(
client::ContextServerId(self.config.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&self.config.executable).to_path_buf(),
args: self.config.args.clone(),
-env: None,
+env: self.config.env.clone(),
},
cx.clone(),
)?;

View File

@@ -26,7 +26,7 @@ pub(crate) mod dispatch_sys {
use dispatch_sys::*;
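// The upcoming Rust 1.81 release warns that the `unsafe` block here is no
// longer needed, hence its removal below.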
pub(crate) fn dispatch_get_main_queue() -> dispatch_queue_t {
-unsafe { addr_of!(_dispatch_main_q) as *const _ as dispatch_queue_t }
+addr_of!(_dispatch_main_q) as *const _ as dispatch_queue_t
}
pub(crate) struct MacDispatcher {

View File

@@ -55,10 +55,19 @@ pub struct LanguageModelCacheConfiguration {
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
+Stop(StopReason),
Text(String),
ToolUse(LanguageModelToolUse),
}
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum StopReason {
+EndTurn,
+MaxTokens,
+ToolUse,
+}
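// These variants mirror Anthropic's `stop_reason` strings ("end_turn",
// "max_tokens", "tool_use"); see the `provider::anthropic` mapping later in this diff.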
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
pub id: String,
@@ -112,6 +121,7 @@ pub trait LanguageModel: Send + Sync {
.filter_map(|result| async move {
match result {
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
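// Stop and tool-use events carry no text, so the text-only stream drops them.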
Err(err) => Some(Err(err)),
}

View File

@@ -3,11 +3,12 @@ use crate::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
-use crate::{LanguageModelCompletionEvent, LanguageModelToolUse};
-use anthropic::AnthropicError;
+use crate::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
+use anthropic::{AnthropicError, ContentDelta, Event, ResponseContent};
use anyhow::{anyhow, Context as _, Result};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
@@ -17,11 +18,13 @@ use http_client::HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
-use util::ResultExt;
+use util::{maybe, ResultExt};
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
@@ -258,12 +261,15 @@ pub fn count_anthropic_tokens(
for content in message.content {
match content {
-MessageContent::Text(string) => {
-string_contents.push_str(&string);
+MessageContent::Text(text) => {
+string_contents.push_str(&text);
}
MessageContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
+MessageContent::ToolResult(tool_result) => {
+string_contents.push_str(&tool_result.content);
+}
}
}
@@ -371,30 +377,9 @@ impl LanguageModel for AnthropicModel {
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
-Ok(anthropic::extract_content_from_events(response))
+Ok(map_to_language_model_completion_events(response))
});
-async move {
-Ok(future
-.await?
-.map(|result| {
-result
-.map(|content| match content {
-anthropic::ResponseContent::Text { text } => {
-LanguageModelCompletionEvent::Text(text)
-}
-anthropic::ResponseContent::ToolUse { id, name, input } => {
-LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
-id,
-name,
-input,
-})
-}
-})
-.map_err(|err| anyhow!(err))
-})
-.boxed())
-}
-.boxed()
+async move { Ok(future.await?.boxed()) }.boxed()
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@@ -443,6 +428,120 @@ impl LanguageModel for AnthropicModel {
}
}
+pub fn map_to_language_model_completion_events(
+events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+struct RawToolUse {
+id: String,
+name: String,
+input_json: String,
+}
+struct State {
+events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+tool_uses_by_index: HashMap<usize, RawToolUse>,
+}
+futures::stream::unfold(
+State {
+events,
+tool_uses_by_index: HashMap::default(),
+},
+|mut state| async move {
+while let Some(event) = state.events.next().await {
+match event {
+Ok(event) => match event {
+Event::ContentBlockStart {
+index,
+content_block,
+} => match content_block {
+ResponseContent::Text { text } => {
+return Some((
+Some(Ok(LanguageModelCompletionEvent::Text(text))),
+state,
+));
+}
+ResponseContent::ToolUse { id, name, .. } => {
+state.tool_uses_by_index.insert(
+index,
+RawToolUse {
+id,
+name,
+input_json: String::new(),
+},
+);
+return Some((None, state));
+}
+},
+Event::ContentBlockDelta { index, delta } => match delta {
+ContentDelta::TextDelta { text } => {
+return Some((
+Some(Ok(LanguageModelCompletionEvent::Text(text))),
+state,
+));
+}
+ContentDelta::InputJsonDelta { partial_json } => {
+if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
+tool_use.input_json.push_str(&partial_json);
+return Some((None, state));
+}
+}
+},
+Event::ContentBlockStop { index } => {
+if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
+return Some((
+Some(maybe!({
+Ok(LanguageModelCompletionEvent::ToolUse(
+LanguageModelToolUse {
+id: tool_use.id,
+name: tool_use.name,
+input: serde_json::Value::from_str(
+&tool_use.input_json,
+)
+.map_err(|err| anyhow!(err))?,
+},
+))
+})),
+state,
+));
+}
+}
+Event::MessageDelta { delta, .. } => {
+if let Some(stop_reason) = delta.stop_reason.as_deref() {
+let stop_reason = match stop_reason {
+"end_turn" => StopReason::EndTurn,
+"max_tokens" => StopReason::MaxTokens,
+"tool_use" => StopReason::ToolUse,
+_ => StopReason::EndTurn,
+};
+return Some((
+Some(Ok(LanguageModelCompletionEvent::Stop(stop_reason))),
+state,
+));
+}
+}
+Event::Error { error } => {
+return Some((
+Some(Err(anyhow!(AnthropicError::ApiError(error)))),
+state,
+));
+}
+_ => {}
+},
+Err(err) => {
+return Some((Some(Err(anyhow!(err))), state));
+}
+}
+}
+None
+},
+)
+.filter_map(|event| async move { event })
+}
struct ConfigurationView {
api_key_editor: View<Editor>,
state: gpui::Model<State>,

View File

@@ -1,4 +1,5 @@
use super::open_ai::count_open_ai_tokens;
+use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
@@ -33,10 +34,7 @@ use std::{
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};
-use crate::{
-LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
-LanguageModelToolUse,
-};
+use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
use super::anthropic::count_anthropic_tokens;
@@ -518,30 +516,11 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
-Ok(anthropic::extract_content_from_events(Box::pin(
+Ok(map_to_language_model_completion_events(Box::pin(
response_lines(response).map_err(AnthropicError::Other),
)))
});
-async move {
-Ok(future
-.await?
-.map(|result| {
-result
-.map(|content| match content {
-anthropic::ResponseContent::Text { text } => {
-LanguageModelCompletionEvent::Text(text)
-}
-anthropic::ResponseContent::ToolUse { id, name, input } => {
-LanguageModelCompletionEvent::ToolUse(
-LanguageModelToolUse { id, name, input },
-)
-}
-})
-.map_err(|err| anyhow!(err))
-})
-.boxed())
-}
-.boxed()
+async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();

View File

@@ -8,14 +8,24 @@ use serde::{Deserialize, Serialize};
use ui::{px, SharedString};
use util::ResultExt;
-#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)]
+#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {
-// A base64 encoded PNG image
+/// A base64-encoded PNG image.
pub source: SharedString,
size: Size<DevicePixels>,
}
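// The hand-written Debug impl added below keeps the (potentially huge) base64
// payload out of logs by printing only the byte count.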
-const ANTHROPIC_SIZE_LIMT: f32 = 1568.0; // Anthropic wants uploaded images to be smaller than this in both dimensions
+impl std::fmt::Debug for LanguageModelImage {
+fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+f.debug_struct("LanguageModelImage")
+.field("source", &format!("<{} bytes>", self.source.len()))
+.field("size", &self.size)
+.finish()
+}
+}
+/// Anthropic wants uploaded images to be smaller than this in both dimensions.
+const ANTHROPIC_SIZE_LIMT: f32 = 1568.;
impl LanguageModelImage {
pub fn from_image(data: Image, cx: &mut AppContext) -> Task<Option<Self>> {
@@ -67,7 +77,7 @@ impl LanguageModelImage {
}
}
-// SAFETY: The base64 encoder should not produce non-UTF8
+// SAFETY: The base64 encoder should not produce non-UTF8.
let source = unsafe { String::from_utf8_unchecked(base64_image) };
Some(LanguageModelImage {
@@ -77,7 +87,7 @@ impl LanguageModelImage {
})
}
-/// Resolves image into an LLM-ready format (base64)
+/// Resolves image into an LLM-ready format (base64).
pub fn from_render_image(data: &RenderImage) -> Option<Self> {
let image_size = data.size(0);
@@ -130,7 +140,7 @@ impl LanguageModelImage {
base64_encoder.write_all(png.as_slice()).log_err()?;
}
-// SAFETY: The base64 encoder should not produce non-UTF8
+// SAFETY: The base64 encoder should not produce non-UTF8.
let source = unsafe { String::from_utf8_unchecked(base64_image) };
Some(LanguageModelImage {
@@ -144,35 +154,32 @@ impl LanguageModelImage {
let height = self.size.height.0.unsigned_abs() as usize;
// From: https://docs.anthropic.com/en/docs/build-with-claude/vision#calculate-image-costs
-// Note that are a lot of conditions on anthropic's API, and OpenAI doesn't use this,
-// so this method is more of a rough guess
+// Note that are a lot of conditions on Anthropic's API, and OpenAI doesn't use this,
+// so this method is more of a rough guess.
(width * height) / 750
}
}
-#[derive(Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
+pub struct LanguageModelToolResult {
+pub tool_use_id: String,
+pub is_error: bool,
+pub content: String,
+}
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum MessageContent {
Text(String),
Image(LanguageModelImage),
-}
-impl std::fmt::Debug for MessageContent {
-fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-match self {
-MessageContent::Text(t) => f.debug_struct("MessageContent").field("text", t).finish(),
-MessageContent::Image(i) => f
-.debug_struct("MessageContent")
-.field("image", &i.source.len())
-.finish(),
-}
-}
+ToolResult(LanguageModelToolResult),
}
impl MessageContent {
pub fn as_string(&self) -> &str {
match self {
-MessageContent::Text(s) => s.as_str(),
+MessageContent::Text(text) => text.as_str(),
MessageContent::Image(_) => "",
+MessageContent::ToolResult(tool_result) => tool_result.content.as_str(),
}
}
}
@@ -200,8 +207,9 @@ impl LanguageModelRequestMessage {
pub fn string_contents(&self) -> String {
let mut string_buffer = String::new();
for string in self.content.iter().filter_map(|content| match content {
-MessageContent::Text(s) => Some(s),
+MessageContent::Text(text) => Some(text),
MessageContent::Image(_) => None,
+MessageContent::ToolResult(tool_result) => Some(&tool_result.content),
}) {
string_buffer.push_str(string.as_str())
}
@@ -214,8 +222,11 @@ impl LanguageModelRequestMessage {
.content
.get(0)
.map(|content| match content {
-MessageContent::Text(s) => s.trim().is_empty(),
+MessageContent::Text(text) => text.trim().is_empty(),
MessageContent::Image(_) => true,
+MessageContent::ToolResult(tool_result) => {
+tool_result.content.trim().is_empty()
+}
})
.unwrap_or(false)
}
@@ -316,21 +327,34 @@ impl LanguageModelRequest {
.content
.into_iter()
.filter_map(|content| match content {
-MessageContent::Text(t) if !t.is_empty() => {
-Some(anthropic::RequestContent::Text {
-text: t,
+MessageContent::Text(text) => {
+if !text.is_empty() {
+Some(anthropic::RequestContent::Text {
+text,
+cache_control,
+})
+} else {
+None
+}
+}
+MessageContent::Image(image) => {
+Some(anthropic::RequestContent::Image {
+source: anthropic::ImageSource {
+source_type: "base64".to_string(),
+media_type: "image/png".to_string(),
+data: image.source.to_string(),
+},
+cache_control,
+})
+}
+MessageContent::ToolResult(tool_result) => {
+Some(anthropic::RequestContent::ToolResult {
+tool_use_id: tool_result.tool_use_id,
+is_error: tool_result.is_error,
+content: tool_result.content,
+cache_control,
+})
+}
-MessageContent::Image(i) => Some(anthropic::RequestContent::Image {
-source: anthropic::ImageSource {
-source_type: "base64".to_string(),
-media_type: "image/png".to_string(),
-data: i.source.to_string(),
-},
-cache_control,
-}),
-_ => None,
})
.collect();
let anthropic_role = match message.role {

View File

@@ -5715,11 +5715,10 @@ impl LspStore {
}
}
if !snippet_edits.is_empty() {
-if let Some(buffer_version) = op.text_document.version {
-let buffer_id = buffer_to_edit.read(cx).remote_id();
-// Check if the edit that triggered that edit has been made by this participant.
-let most_recent_edit = this
-.buffer_snapshots
+let buffer_id = buffer_to_edit.read(cx).remote_id();
+let version = if let Some(buffer_version) = op.text_document.version
+{
+this.buffer_snapshots
.get(&buffer_id)
.and_then(|server_to_snapshots| {
let all_snapshots = server_to_snapshots
@@ -5731,17 +5730,22 @@ impl LspStore {
.ok()
.and_then(|index| all_snapshots.get(index))
})
-.and_then(|lsp_snapshot| {
-let version = lsp_snapshot.snapshot.version();
-version.iter().max_by_key(|timestamp| timestamp.value)
-});
-if let Some(most_recent_edit) = most_recent_edit {
-cx.emit(LspStoreEvent::SnippetEdit {
-buffer_id,
-edits: snippet_edits,
-most_recent_edit,
-});
-}
+.map(|lsp_snapshot| lsp_snapshot.snapshot.version())
+} else {
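// (i.e. when the LSP edit carries no document version, fall back to the
// buffer's last saved version)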
+Some(buffer_to_edit.read(cx).saved_version())
+};
+let most_recent_edit = version.and_then(|version| {
+version.iter().max_by_key(|timestamp| timestamp.value)
+});
+// Check if the edit that triggered that edit has been made by this participant.
+if let Some(most_recent_edit) = most_recent_edit {
+cx.emit(LspStoreEvent::SnippetEdit {
+buffer_id,
+edits: snippet_edits,
+most_recent_edit,
+});
+}
}

View File

@@ -2,7 +2,7 @@
description = "The fast, collaborative code editor."
edition = "2021"
name = "zed"
version = "0.152.0"
version = "0.153.0"
publish = false
license = "GPL-3.0-or-later"
authors = ["Zed Team <hi@zed.dev>"]