Compare commits

...

7 Commits

Author SHA1 Message Date
Antonio Scandurra
33d2c8f726 Checkpoint
Co-Authored-By: Nathan <nathan@zed.dev>
2024-07-29 16:21:17 +02:00
Antonio Scandurra
1e160f22ce WIP 2024-07-29 14:04:54 +02:00
Antonio Scandurra
fdb6f90abb Document tool schema
Co-Authored-By: Nathan <nathan@zed.dev>
2024-07-29 13:10:11 +02:00
Antonio Scandurra
d473c6892d Pass Config instead of individual API keys
Co-Authored-By: Nathan <nathan@zed.dev>
2024-07-29 12:58:29 +02:00
Nathan Sobo
a036b64fb9 Refactor language model API to separate endpoints
This commit refactors the language model API to have separate endpoints for
completion, streaming completion, and token counting. This change allows for:

1. Better type safety and clearer interfaces for each operation
2. Separate rate limiting for different types of requests
3. More flexibility in implementing provider-specific logic for each operation

The main changes include:
- Replacing QueryLanguageModel with CompleteWithLanguageModel,
  StreamCompleteWithLanguageModel, and CountLanguageModelTokens
- Updating the server to handle these new request types
- Adjusting rate limiting for each operation type
- Modifying the protocol buffer definitions to reflect the new structure

This refactoring improves the overall design and maintainability of the
language model integration in the project.
2024-07-29 12:52:43 +02:00
Antonio Scandurra
5fa3a8256c Use tool calling instead of XML parsing to generate edit operations
Co-Authored-By: Nathan <nathan@zed.dev>
2024-07-28 22:00:37 +02:00
Antonio Scandurra
6a9951769d Introduce tool calling for Anthropic models 2024-07-28 22:00:30 +02:00
22 changed files with 1152 additions and 853 deletions

13
Cargo.lock generated
View File

@@ -435,7 +435,6 @@ dependencies = [
"rand 0.8.5",
"regex",
"rope",
"roxmltree 0.20.0",
"schemars",
"search",
"semantic_index",
@@ -2641,7 +2640,9 @@ dependencies = [
"language_model",
"project",
"rand 0.8.5",
"schemars",
"serde",
"serde_json",
"settings",
"smol",
"text",
@@ -4237,7 +4238,7 @@ version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d"
dependencies = [
"roxmltree 0.19.0",
"roxmltree",
]
[[package]]
@@ -8918,12 +8919,6 @@ version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f"
[[package]]
name = "roxmltree"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97"
[[package]]
name = "rpc"
version = "0.1.0"
@@ -11877,7 +11872,7 @@ dependencies = [
"kurbo",
"log",
"pico-args",
"roxmltree 0.19.0",
"roxmltree",
"simplecss",
"siphasher 1.0.1",
"strict-num",

View File

@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};
use std::time::Duration;
use strum::EnumIter;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
@@ -70,112 +70,53 @@ impl Model {
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<Response> {
let uri = format!("{api_url}/v1/messages");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
.header("Anthropic-Beta", "tools-2024-04-04")
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
impl TryFrom<String> for Role {
type Error = anyhow::Error;
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
fn try_from(value: String) -> Result<Self> {
match value.as_str() {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
_ => Err(anyhow!("invalid role '{value}'")),
}
let mut response = client.send(request).await?;
if response.status().is_success() {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_message: Response = serde_json::from_slice(&body)?;
Ok(response_message)
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
Err(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str
))
}
}
impl From<Role> for String {
fn from(val: Role) -> Self {
match val {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub system: String,
pub max_tokens: u32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
message: ResponseMessage,
},
ContentBlockStart {
index: u32,
content_block: ContentBlock,
},
Ping {},
ContentBlockDelta {
index: u32,
delta: TextDelta,
},
ContentBlockStop {
index: u32,
},
MessageDelta {
delta: ResponseMessage,
usage: Usage,
},
MessageStop {},
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
pub id: Option<String>,
pub role: Option<String>,
pub content: Option<Vec<String>>,
pub model: Option<String>,
pub stop_reason: Option<String>,
pub stop_sequence: Option<String>,
pub usage: Option<Usage>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
) -> Result<BoxStream<'static, Result<Event>>> {
let request = StreamingRequest {
base: request,
stream: true,
};
let uri = format!("{api_url}/v1/messages");
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
@@ -187,7 +128,9 @@ pub async fn stream_completion(
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
}
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -212,7 +155,7 @@ pub async fn stream_completion(
let body_str = std::str::from_utf8(&body)?;
match serde_json::from_str::<ResponseEvent>(body_str) {
match serde_json::from_str::<Event>(body_str) {
Ok(_) => Err(anyhow!(
"Unexpected success response while expecting an error: {}",
body_str,
@@ -227,16 +170,18 @@ pub async fn stream_completion(
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseEvent>>,
response: impl Stream<Item = Result<Event>>,
) -> impl Stream<Item = Result<String>> {
response.filter_map(|response| async move {
match response {
Ok(response) => match response {
ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
ContentBlock::Text { text } => Some(Ok(text)),
Event::ContentBlockStart { content_block, .. } => match content_block {
Content::Text { text } => Some(Ok(text)),
_ => None,
},
ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
TextDelta::TextDelta { text } => Some(Ok(text)),
Event::ContentBlockDelta { delta, .. } => match delta {
ContentDelta::TextDelta { text } => Some(Ok(text)),
_ => None,
},
_ => None,
},
@@ -245,42 +190,162 @@ pub fn extract_text_from_events(
})
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use http::IsahcHttpClient;
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: Vec<Content>,
}
// #[tokio::test]
// async fn stream_completion_success() {
// let http_client = IsahcHttpClient::new().unwrap();
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
// let request = Request {
// model: Model::Claude3Opus,
// messages: vec![RequestMessage {
// role: Role::User,
// content: "Ping".to_string(),
// }],
// stream: true,
// system: "Respond to ping with pong".to_string(),
// max_tokens: 4096,
// };
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Content {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { source: ImageSource },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
},
}
// let stream = stream_completion(
// &http_client,
// "https://api.anthropic.com",
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
// request,
// )
// .await
// .unwrap();
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
#[serde(rename = "type")]
pub source_type: String,
pub media_type: String,
pub data: String,
}
// stream
// .for_each(|event| async {
// match event {
// Ok(event) => println!("{:?}", event),
// Err(e) => eprintln!("Error: {:?}", e),
// }
// })
// .await;
// }
// }
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
Auto,
Any,
Tool { name: String },
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub max_tokens: u32,
pub messages: Vec<Message>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<Metadata>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
#[serde(flatten)]
pub base: Request,
pub stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
pub user_id: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input_tokens: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_tokens: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
pub id: String,
#[serde(rename = "type")]
pub response_type: String,
pub role: Role,
pub content: Vec<Content>,
pub model: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_sequence: Option<String>,
pub usage: Usage,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
#[serde(rename = "message_start")]
MessageStart { message: Response },
#[serde(rename = "content_block_start")]
ContentBlockStart {
index: usize,
content_block: Content,
},
#[serde(rename = "content_block_delta")]
ContentBlockDelta { index: usize, delta: ContentDelta },
#[serde(rename = "content_block_stop")]
ContentBlockStop { index: usize },
#[serde(rename = "message_delta")]
MessageDelta { delta: MessageDelta, usage: Usage },
#[serde(rename = "message_stop")]
MessageStop,
#[serde(rename = "ping")]
Ping,
#[serde(rename = "error")]
Error { error: ApiError },
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
pub stop_reason: Option<String>,
pub stop_sequence: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
#[serde(rename = "type")]
pub error_type: String,
pub message: String,
}

View File

@@ -75,7 +75,6 @@ util.workspace = true
uuid.workspace = true
workspace.workspace = true
picker.workspace = true
roxmltree = "0.20.0"
[dev-dependencies]
completion = { workspace = true, features = ["test-support"] }

View File

@@ -1207,12 +1207,16 @@ impl ContextEditor {
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
if let Some(step) = self.active_edit_step.as_ref() {
InlineAssistant::update_global(cx, |assistant, cx| {
for assist_id in &step.assist_ids {
assistant.start_assist(*assist_id, cx);
}
!step.assist_ids.is_empty()
})
let assist_ids = step.assist_ids.clone();
cx.window_context().defer(|cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
for assist_id in assist_ids {
assistant.start_assist(assist_id, cx);
}
})
});
!step.assist_ids.is_empty()
} else {
false
}
@@ -1261,11 +1265,7 @@ impl ContextEditor {
.collect::<String>()
));
match &step.operations {
Some(EditStepOperations::Parsed {
operations,
raw_output,
}) => {
output.push_str(&format!("Raw Output:\n{raw_output}\n"));
Some(EditStepOperations::Ready(operations)) => {
output.push_str("Parsed Operations:\n");
for op in operations {
output.push_str(&format!(" {:?}\n", op));
@@ -1769,13 +1769,12 @@ impl ContextEditor {
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
.unwrap()
};
let initial_text = suggestion.prepend_newline.then(|| "\n".into());
InlineAssistant::update_global(cx, |assistant, cx| {
assist_ids.push(assistant.suggest_assist(
&editor,
range,
description,
initial_text,
suggestion.initial_insertion,
Some(workspace.clone()),
assistant_panel.upgrade().as_ref(),
cx,
@@ -1837,9 +1836,11 @@ impl ContextEditor {
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
.unwrap()
};
let initial_text =
suggestion.prepend_newline.then(|| "\n".to_string());
inline_assist_suggestions.push((range, description, initial_text));
inline_assist_suggestions.push((
range,
description,
suggestion.initial_insertion,
));
}
}
}
@@ -1850,12 +1851,12 @@ impl ContextEditor {
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
cx.update(|cx| {
InlineAssistant::update_global(cx, |assistant, cx| {
for (range, description, initial_text) in inline_assist_suggestions {
for (range, description, initial_insertion) in inline_assist_suggestions {
assist_ids.push(assistant.suggest_assist(
&editor,
range,
description,
initial_text,
initial_insertion,
Some(workspace.clone()),
assistant_panel.upgrade().as_ref(),
cx,
@@ -2163,7 +2164,7 @@ impl ContextEditor {
let button_text = match self.edit_step_for_cursor(cx) {
Some(edit_step) => match &edit_step.operations {
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
Some(EditStepOperations::Ready(_)) => "Apply Changes",
None => "Send",
},
None => "Send",

View File

@@ -1,6 +1,6 @@
use crate::{
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
MessageId, MessageStatus,
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
LanguageModelCompletionProvider, MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
use language_model::LanguageModelRequestMessage;
use language_model::{LanguageModelRequest, Role};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
cmp,
@@ -352,7 +352,7 @@ pub struct EditSuggestion {
pub range: Range<language::Anchor>,
/// If None, assume this is a suggestion to delete the range rather than transform it.
pub description: Option<String>,
pub prepend_newline: bool,
pub initial_insertion: Option<InitialInsertion>,
}
impl EditStep {
@@ -361,7 +361,7 @@ impl EditStep {
project: &Model<Project>,
cx: &AppContext,
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
return Task::ready(HashMap::default());
};
@@ -471,32 +471,28 @@ impl EditStep {
}
pub enum EditStepOperations {
Pending(Task<Result<()>>),
Parsed {
operations: Vec<EditOperation>,
raw_output: String,
},
Pending(Task<Option<()>>),
Ready(Vec<EditOperation>),
}
impl Debug for EditStepOperations {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
EditStepOperations::Parsed {
operations,
raw_output,
} => f
EditStepOperations::Ready(operations) => f
.debug_struct("EditStepOperations::Parsed")
.field("operations", operations)
.field("raw_output", raw_output)
.finish(),
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
/// A description of an operation to apply to one location in the codebase.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
pub struct EditOperation {
/// The path to the file containing the relevant operation
pub path: String,
#[serde(flatten)]
pub kind: EditOperationKind,
}
@@ -523,7 +519,7 @@ impl EditOperation {
parse_status.changed().await?;
}
let prepend_newline = kind.prepend_newline();
let initial_insertion = kind.initial_insertion();
let suggestion_range = if let Some(symbol) = kind.symbol() {
let outline = buffer
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@@ -601,39 +597,61 @@ impl EditOperation {
EditSuggestion {
range: suggestion_range,
description: kind.description().map(ToString::to_string),
prepend_newline,
initial_insertion,
},
))
})
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
#[serde(tag = "kind")]
pub enum EditOperationKind {
/// Rewrite the specified symbol in its entirely based on the given description.
Update {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Create a new file with the given path based on the given description.
Create {
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol based on the given description before the specified symbol.
InsertSiblingBefore {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol based on the given description after the specified symbol.
InsertSiblingAfter {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol as a child of the specified symbol at the start.
PrependChild {
/// An optional full path to the symbol to be rewritten from the provided list.
/// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Insert a new symbol as a child of the specified symbol at the end.
AppendChild {
/// An optional full path to the symbol to be rewritten from the provided list.
/// If not provided, the edit should be applied at the top of the file.
symbol: Option<String>,
/// A brief one-line description of the change that should be applied.
description: String,
},
/// Delete the specified symbol.
Delete {
/// A full path to the symbol to be rewritten from the provided list.
symbol: String,
},
}
@@ -663,13 +681,13 @@ impl EditOperationKind {
}
}
pub fn prepend_newline(&self) -> bool {
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
match self {
Self::PrependChild { .. }
| Self::AppendChild { .. }
| Self::InsertSiblingAfter { .. }
| Self::InsertSiblingBefore { .. } => true,
_ => false,
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
_ => None,
}
}
}
@@ -1137,18 +1155,15 @@ impl Context {
.timer(Duration::from_millis(200))
.await;
if let Some(token_count) = cx.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})? {
let token_count = token_count.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
})?;
}
anyhow::Ok(())
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify()
})
}
.log_err()
});
@@ -1304,7 +1319,24 @@ impl Context {
&self,
edit_step: &EditStep,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
) -> Task<Option<()>> {
#[derive(Debug, Deserialize, JsonSchema)]
struct EditTool {
/// A sequence of operations to apply to the codebase.
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
operations: Vec<EditOperation>,
}
impl LanguageModelTool for EditTool {
fn name() -> String {
"edit".into()
}
fn description() -> String {
"suggest edits to one or more locations in the codebase".into()
}
}
let mut request = self.to_completion_request(cx);
let edit_step_range = edit_step.source_range.clone();
let step_text = self
@@ -1313,160 +1345,41 @@ impl Context {
.text_for_range(edit_step_range.clone())
.collect::<String>();
cx.spawn(|this, mut cx| async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
cx.spawn(|this, mut cx| {
async move {
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
let mut prompt = prompt_store.operations_prompt();
prompt.push_str(&step_text);
let mut prompt = prompt_store.operations_prompt();
prompt.push_str(&step_text);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
let raw_output = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
let tool_use = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.use_tool::<EditTool>(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
let step_index = this
.edit_steps
.binary_search_by(|step| {
step.source_range
.cmp(&edit_step_range, this.buffer.read(cx))
})
.map_err(|_| anyhow!("edit step not found"))?;
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
cx.emit(ContextEvent::EditStepsChanged);
}
anyhow::Ok(())
})?
.await?;
let operations = Self::parse_edit_operations(&raw_output);
this.update(&mut cx, |this, cx| {
let step_index = this
.edit_steps
.binary_search_by(|step| {
step.source_range
.cmp(&edit_step_range, this.buffer.read(cx))
})
.map_err(|_| anyhow!("edit step not found"))?;
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
edit_step.operations = Some(EditStepOperations::Parsed {
operations,
raw_output,
});
cx.emit(ContextEvent::EditStepsChanged);
}
anyhow::Ok(())
})?
})
}
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
let Some(start_ix) = xml.find("<operations>") else {
return Vec::new();
};
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
return Vec::new();
};
let end_ix = end_ix + start_ix + "</operations>".len();
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
doc.map_or(Vec::new(), |doc| {
doc.root_element()
.children()
.map(|node| {
let tag_name = node.tag_name().name();
let path = node
.attribute("path")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'path'")
})?
.to_string();
let kind = match tag_name {
"update" => EditOperationKind::Update {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"create" => EditOperationKind::Create {
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"prepend_child" => EditOperationKind::PrependChild {
symbol: node.attribute("symbol").map(String::from),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"append_child" => EditOperationKind::AppendChild {
symbol: node.attribute("symbol").map(String::from),
description: node
.attribute("description")
.with_context(|| {
format!(
"invalid node {node:?}, missing attribute 'description'"
)
})?
.to_string(),
},
"delete" => EditOperationKind::Delete {
symbol: node
.attribute("symbol")
.with_context(|| {
format!("invalid node {node:?}, missing attribute 'symbol'")
})?
.to_string(),
},
_ => return Err(anyhow!("invalid node {node:?}")),
};
anyhow::Ok(EditOperation { path, kind })
})
.filter_map(|op| op.log_err())
.collect()
}
.log_err()
})
}
@@ -3083,55 +2996,6 @@ mod tests {
}
}
#[test]
fn test_parse_edit_operations() {
let operations = indoc! {r#"
Here are the operations to make all fields of the Canvas struct private:
<operations>
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
</operations>
"#};
let parsed_operations = Context::parse_edit_operations(operations);
assert_eq!(
parsed_operations,
vec![
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub pixels".to_string(),
description: "Remove pub keyword from pixels field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub size".to_string(),
description: "Remove pub keyword from size field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub stride".to_string(),
description: "Remove pub keyword from stride field".to_string(),
},
},
EditOperation {
path: "font-kit/src/canvas.rs".to_string(),
kind: EditOperationKind::Update {
symbol: "pub struct Canvas pub format".to_string(),
description: "Remove pub keyword from format field".to_string(),
},
},
]
);
}
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);

View File

@@ -17,7 +17,7 @@ use editor::{
use fs::Fs;
use futures::{
channel::mpsc,
future::LocalBoxFuture,
future::{BoxFuture, LocalBoxFuture},
stream::{self, BoxStream},
SinkExt, Stream, StreamExt,
};
@@ -36,7 +36,7 @@ use similar::TextDiff;
use smol::future::FutureExt;
use std::{
cmp,
future::Future,
future::{self, Future},
mem,
ops::{Range, RangeInclusive},
pin::Pin,
@@ -46,7 +46,7 @@ use std::{
};
use theme::ThemeSettings;
use ui::{prelude::*, IconButtonShape, Tooltip};
use util::RangeExt;
use util::{RangeExt, ResultExt};
use workspace::{notifications::NotificationId, Toast, Workspace};
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
@@ -187,7 +187,13 @@ impl InlineAssistant {
let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
assists.push((
assist_id,
range,
prompt_editor,
prompt_block_id,
end_block_id,
));
}
let editor_assists = self
@@ -195,7 +201,7 @@ impl InlineAssistant {
.entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
let mut assist_group = InlineAssistGroup::new();
for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
self.assists.insert(
assist_id,
InlineAssist::new(
@@ -206,6 +212,7 @@ impl InlineAssistant {
&prompt_editor,
prompt_block_id,
end_block_id,
range,
prompt_editor.read(cx).codegen.clone(),
workspace.clone(),
cx,
@@ -227,7 +234,7 @@ impl InlineAssistant {
editor: &View<Editor>,
mut range: Range<Anchor>,
initial_prompt: String,
initial_insertion: Option<String>,
initial_insertion: Option<InitialInsertion>,
workspace: Option<WeakView<Workspace>>,
assistant_panel: Option<&View<AssistantPanel>>,
cx: &mut WindowContext,
@@ -239,22 +246,30 @@ impl InlineAssistant {
let assist_id = self.next_assist_id.post_inc();
let buffer = editor.read(cx).buffer().clone();
let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
buffer.update(cx, |buffer, cx| {
buffer.start_transaction(cx);
buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
buffer.end_transaction(cx)
})
});
{
let snapshot = buffer.read(cx).read(cx);
range.start = range.start.bias_left(&buffer.read(cx).read(cx));
range.end = range.end.bias_right(&buffer.read(cx).read(cx));
let mut point_range = range.to_point(&snapshot);
if point_range.is_empty() {
point_range.start.column = 0;
point_range.end.column = 0;
} else {
point_range.start.column = 0;
if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
point_range.end.row -= 1;
}
point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
}
range.start = snapshot.anchor_before(point_range.start);
range.end = snapshot.anchor_after(point_range.end);
}
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
prepend_transaction_id,
initial_insertion,
self.telemetry.clone(),
cx,
)
@@ -295,6 +310,7 @@ impl InlineAssistant {
&prompt_editor,
prompt_block_id,
end_block_id,
range,
prompt_editor.read(cx).codegen.clone(),
workspace.clone(),
cx,
@@ -445,7 +461,7 @@ impl InlineAssistant {
let buffer = editor.buffer().read(cx).snapshot(cx);
for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id];
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
let assist_range = assist.range.to_offset(&buffer);
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
{
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
@@ -473,7 +489,7 @@ impl InlineAssistant {
let buffer = editor.buffer().read(cx).snapshot(cx);
for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id];
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
let assist_range = assist.range.to_offset(&buffer);
if assist.decorations.is_some()
&& assist_range.contains(&selection.start)
&& assist_range.contains(&selection.end)
@@ -551,7 +567,7 @@ impl InlineAssistant {
assist.codegen.read(cx).status,
CodegenStatus::Error(_) | CodegenStatus::Done
) {
let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
let assist_range = assist.range.to_offset(&snapshot);
if edited_ranges
.iter()
.any(|range| range.overlaps(&assist_range))
@@ -721,7 +737,7 @@ impl InlineAssistant {
});
}
let position = assist.codegen.read(cx).range.start;
let position = assist.range.start;
editor.update(cx, |editor, cx| {
editor.change_selections(None, cx, |selections| {
selections.select_anchor_ranges([position..position])
@@ -740,8 +756,7 @@ impl InlineAssistant {
.0 as f32;
} else {
let snapshot = editor.snapshot(cx);
let codegen = assist.codegen.read(cx);
let start_row = codegen
let start_row = assist
.range
.start
.to_display_point(&snapshot.display_snapshot)
@@ -829,11 +844,7 @@ impl InlineAssistant {
return;
}
let Some(user_prompt) = assist
.decorations
.as_ref()
.map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
else {
let Some(user_prompt) = assist.user_prompt(cx) else {
return;
};
@@ -843,139 +854,19 @@ impl InlineAssistant {
self.prompt_history.pop_front();
}
let codegen = assist.codegen.clone();
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request = self.request_for_inline_assist(assist_id, cx);
let mut cx = cx.to_async();
async move {
let request = request.await?;
let chunks = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.stream_completion(request, cx)
})?
.await?;
Ok(chunks.boxed())
}
.boxed_local()
};
codegen.update(cx, |codegen, cx| {
codegen.start(telemetry_id, chunks, cx);
});
}
let assistant_panel_context = assist.assistant_panel_context(cx);
fn request_for_inline_assist(
&self,
assist_id: InlineAssistId,
cx: &mut WindowContext,
) -> Task<Result<LanguageModelRequest>> {
cx.spawn(|mut cx| async move {
let (user_prompt, context_request, project_name, buffer, range) =
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let assist = this.assists.get(&assist_id).context("invalid assist")?;
let decorations = assist.decorations.as_ref().context("invalid assist")?;
let editor = assist.editor.upgrade().context("invalid assist")?;
let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
let context_request = if assist.include_context {
assist.workspace.as_ref().and_then(|workspace| {
let workspace = workspace.upgrade()?.read(cx);
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
Some(
assistant_panel
.read(cx)
.active_context(cx)?
.read(cx)
.to_completion_request(cx),
)
})
} else {
None
};
let project_name = assist.workspace.as_ref().and_then(|workspace| {
let workspace = workspace.upgrade()?;
Some(
workspace
.read(cx)
.project()
.read(cx)
.worktree_root_names(cx)
.collect::<Vec<&str>>()
.join("/"),
)
});
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
let range = assist.codegen.read(cx).range.clone();
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
})??;
let language = buffer.language_at(range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
} else {
Some(language.name())
}
} else {
None
};
// Higher Temperature increases the randomness of model outputs.
// If Markdown or No Language is Known, increase the randomness for more creative output
// If Code, decrease temperature to get more deterministic outputs
let temperature = if let Some(language) = language_name.clone() {
if language.as_ref() == "Markdown" {
1.0
} else {
0.5
}
} else {
1.0
};
let prompt = cx
.background_executor()
.spawn(async move {
let language_name = language_name.as_deref();
let start = buffer.point_to_buffer_offset(range.start);
let end = buffer.point_to_buffer_offset(range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
} else {
return Err(anyhow!("invalid transformation range"));
}
} else {
return Err(anyhow!("invalid transformation range"));
};
generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
})
.await?;
let mut messages = Vec::new();
if let Some(context_request) = context_request {
messages = context_request.messages;
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
Ok(LanguageModelRequest {
messages,
stop: vec!["|END|>".to_string()],
temperature,
assist
.codegen
.update(cx, |codegen, cx| {
codegen.start(
assist.range.clone(),
user_prompt,
assistant_panel_context,
cx,
)
})
})
.log_err();
}
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
@@ -1006,12 +897,11 @@ impl InlineAssistant {
let codegen = assist.codegen.read(cx);
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
if codegen.edit_position != codegen.range.end {
gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
}
gutter_pending_ranges
.push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
if codegen.range.start != codegen.edit_position {
gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
if let Some(edit_position) = codegen.edit_position {
gutter_transformed_ranges.push(assist.range.start..edit_position);
}
if assist.decorations.is_some() {
@@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
})
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
NewlineBefore,
NewlineAfter,
}
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
@@ -1629,24 +1525,20 @@ impl PromptEditor {
let assist_id = self.id;
self.pending_token_count = cx.spawn(|this, mut cx| async move {
cx.background_executor().timer(Duration::from_secs(1)).await;
let request = cx
let token_count = cx
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
inline_assistant.request_for_inline_assist(assist_id, cx)
})?
let assist = inline_assistant
.assists
.get(&assist_id)
.context("assist not found")?;
anyhow::Ok(assist.count_tokens(cx))
})??
.await?;
if let Some(token_count) = cx.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})? {
let token_count = token_count.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
})
} else {
Ok(())
}
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
})
})
}
@@ -1855,6 +1747,7 @@ impl PromptEditor {
struct InlineAssist {
group_id: InlineAssistGroupId,
range: Range<Anchor>,
editor: WeakView<Editor>,
decorations: Option<InlineAssistDecorations>,
codegen: Model<Codegen>,
@@ -1873,6 +1766,7 @@ impl InlineAssist {
prompt_editor: &View<PromptEditor>,
prompt_block_id: CustomBlockId,
end_block_id: CustomBlockId,
range: Range<Anchor>,
codegen: Model<Codegen>,
workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext,
@@ -1888,6 +1782,7 @@ impl InlineAssist {
removed_line_block_ids: HashSet::default(),
end_block_id,
}),
range,
codegen: codegen.clone(),
workspace: workspace.clone(),
_subscriptions: vec![
@@ -1963,6 +1858,41 @@ impl InlineAssist {
],
}
}
fn user_prompt(&self, cx: &AppContext) -> Option<String> {
let decorations = self.decorations.as_ref()?;
Some(decorations.prompt_editor.read(cx).prompt(cx))
}
fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
if self.include_context {
let workspace = self.workspace.as_ref()?;
let workspace = workspace.upgrade()?.read(cx);
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
Some(
assistant_panel
.read(cx)
.active_context(cx)?
.read(cx)
.to_completion_request(cx),
)
} else {
None
}
}
pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
let Some(user_prompt) = self.user_prompt(cx) else {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
self.codegen.read(cx).count_tokens(
self.range.clone(),
user_prompt,
assistant_panel_context,
cx,
)
}
}
struct InlineAssistDecorations {
@@ -1982,16 +1912,15 @@ pub struct Codegen {
buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot,
range: Range<Anchor>,
edit_position: Anchor,
edit_position: Option<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>,
prepend_transaction_id: Option<TransactionId>,
generation_transaction_id: Option<TransactionId>,
transaction_id: Option<TransactionId>,
status: CodegenStatus,
generation: Task<()>,
diff: Diff,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
initial_insertion: Option<InitialInsertion>,
}
enum CodegenStatus {
@@ -2015,7 +1944,7 @@ impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
prepend_transaction_id: Option<TransactionId>,
initial_insertion: Option<InitialInsertion>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -2044,17 +1973,16 @@ impl Codegen {
Self {
buffer: buffer.clone(),
old_buffer,
edit_position: range.start,
range,
edit_position: None,
snapshot,
last_equal_ranges: Default::default(),
prepend_transaction_id,
generation_transaction_id: None,
transaction_id: None,
status: CodegenStatus::Idle,
generation: Task::ready(()),
diff: Diff::default(),
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_insertion,
}
}
@@ -2065,13 +1993,8 @@ impl Codegen {
cx: &mut ModelContext<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.generation_transaction_id == Some(*transaction_id) {
self.generation_transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
} else if self.prepend_transaction_id == Some(*transaction_id) {
self.prepend_transaction_id = None;
self.generation_transaction_id = None;
if self.transaction_id == Some(*transaction_id) {
self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
}
@@ -2082,19 +2005,152 @@ impl Codegen {
&self.last_equal_ranges
}
pub fn count_tokens(
&self,
edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
}
pub fn start(
&mut self,
telemetry_id: String,
mut edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
self.undo(cx);
// Handle initial insertion
self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
self.buffer.update(cx, |buffer, cx| {
buffer.start_transaction(cx);
let offset = edit_range.start.to_offset(&self.snapshot);
let edit_position;
match initial_insertion {
InitialInsertion::NewlineBefore => {
buffer.edit([(offset..offset, "\n\n")], None, cx);
self.snapshot = buffer.snapshot(cx);
edit_position = self.snapshot.anchor_after(offset + 1);
}
InitialInsertion::NewlineAfter => {
buffer.edit([(offset..offset, "\n")], None, cx);
self.snapshot = buffer.snapshot(cx);
edit_position = self.snapshot.anchor_after(offset);
}
}
self.edit_position = Some(edit_position);
edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
buffer.end_transaction(cx)
})
} else {
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
None
};
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model_telemetry_id()
.context("no active model")?;
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
.trim()
.to_lowercase()
== "delete"
{
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request =
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
let chunks =
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
async move { Ok(chunks.await?.boxed()) }.boxed_local()
};
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
Ok(())
}
fn build_request(
&self,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
edit_range: Range<Anchor>,
cx: &AppContext,
) -> LanguageModelRequest {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(edit_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
} else {
Some(language.name())
}
} else {
None
};
// Higher Temperature increases the randomness of model outputs.
// If Markdown or No Language is Known, increase the randomness for more creative output
// If Code, decrease temperature to get more deterministic outputs
let temperature = if let Some(language) = language_name.clone() {
if language.as_ref() == "Markdown" {
1.0
} else {
0.5
}
} else {
1.0
};
let language_name = language_name.as_deref();
let start = buffer.point_to_buffer_offset(edit_range.start);
let end = buffer.point_to_buffer_offset(edit_range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
} else {
panic!("invalid transformation range");
}
} else {
panic!("invalid transformation range");
};
let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
let mut messages = Vec::new();
if let Some(context_request) = assistant_panel_context {
messages = context_request.messages;
}
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
LanguageModelRequest {
messages,
stop: vec!["|END|>".to_string()],
temperature,
}
}
pub fn handle_stream(
&mut self,
model_telemetry_id: String,
edit_range: Range<Anchor>,
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
cx: &mut ModelContext<Self>,
) {
let range = self.range.clone();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
.text_for_range(range.start..range.end)
.text_for_range(edit_range.start..edit_range.end)
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
let selection_start = edit_range.start.to_point(&snapshot);
// Start with the indentation of the first line in the selection
let mut suggested_line_indent = snapshot
@@ -2105,7 +2161,7 @@ impl Codegen {
// If the first line in the selection does not have indentation, check the following lines
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
for row in selection_start.row..=range.end.to_point(&snapshot).row {
for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
// Prefer tabs if a line in the selection uses tabs as indentation
if line_indent.kind == IndentKind::Tab {
@@ -2116,19 +2172,13 @@ impl Codegen {
}
let telemetry = self.telemetry.clone();
self.edit_position = range.start;
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
if let Some(transaction_id) = self.generation_transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
let mut edit_start = edit_range.start.to_offset(&snapshot);
self.generation = cx.spawn(|this, mut cx| {
async move {
let chunks = stream.await;
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff: Task<anyhow::Result<()>> =
cx.background_executor().spawn(async move {
@@ -2218,7 +2268,7 @@ impl Codegen {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
telemetry_id,
model_telemetry_id,
response_latency,
error_message,
);
@@ -2262,13 +2312,13 @@ impl Codegen {
None,
cx,
);
this.edit_position = snapshot.anchor_after(edit_start);
this.edit_position = Some(snapshot.anchor_after(edit_start));
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
if let Some(first_transaction) = this.generation_transaction_id {
if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
@@ -2278,14 +2328,14 @@ impl Codegen {
)
});
} else {
this.generation_transaction_id = Some(transaction);
this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
this.update_diff(cx);
this.update_diff(edit_range.clone(), cx);
cx.notify();
})?;
}
@@ -2321,27 +2371,22 @@ impl Codegen {
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
if let Some(transaction_id) = self.prepend_transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
if let Some(transaction_id) = self.generation_transaction_id.take() {
if let Some(transaction_id) = self.transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
}
fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
if self.diff.task.is_some() {
self.diff.should_update = true;
} else {
self.diff.should_update = false;
let old_snapshot = self.snapshot.clone();
let old_range = self.range.to_point(&old_snapshot);
let old_range = edit_range.to_point(&old_snapshot);
let new_snapshot = self.buffer.read(cx).snapshot(cx);
let new_range = self.range.to_point(&new_snapshot);
let new_range = edit_range.to_point(&new_snapshot);
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
let (deleted_row_ranges, inserted_row_ranges) = cx
@@ -2422,7 +2467,7 @@ impl Codegen {
this.diff.inserted_row_ranges = inserted_row_ranges;
this.diff.task = None;
if this.diff.should_update {
this.update_diff(cx);
this.update_diff(edit_range, cx);
}
cx.notify();
})
@@ -2629,12 +2674,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
codegen.handle_stream(
String::new(),
range,
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2690,12 +2737,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2755,12 +2804,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
@@ -2819,12 +2870,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)

View File

@@ -734,29 +734,27 @@ impl PromptLibrary {
const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1);
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
if let Some(token_count) = cx.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(
LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: body.to_string(),
}],
stop: Vec::new(),
temperature: 1.,
},
cx,
)
})? {
let token_count = token_count.await?;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(
LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: body.to_string(),
}],
stop: Vec::new(),
temperature: 1.,
},
cx,
)
})?
.await?;
this.update(&mut cx, |this, cx| {
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
prompt_editor.token_count = Some(token_count);
cx.notify();
})
} else {
Ok(())
}
this.update(&mut cx, |this, cx| {
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
prompt_editor.token_count = Some(token_count);
cx.notify();
})
}
.log_err()
});

View File

@@ -6,8 +6,7 @@ pub fn generate_content_prompt(
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
_project_name: Option<String>,
) -> anyhow::Result<String> {
) -> String {
let mut prompt = String::new();
let content_type = match language_name {
@@ -15,14 +14,16 @@ pub fn generate_content_prompt(
writeln!(
prompt,
"Here's a file of text that I'm going to ask you to make an edit to."
)?;
)
.unwrap();
"text"
}
Some(language_name) => {
writeln!(
prompt,
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
)?;
)
.unwrap();
"code"
}
};
@@ -70,7 +71,7 @@ pub fn generate_content_prompt(
write!(prompt, "</document>\n\n").unwrap();
if is_truncated {
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
}
if range.is_empty() {
@@ -107,7 +108,7 @@ pub fn generate_content_prompt(
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
}
Ok(prompt)
prompt
}
pub fn generate_terminal_assistant_prompt(

View File

@@ -707,18 +707,15 @@ impl PromptEditor {
inline_assistant.request_for_inline_assist(assist_id, cx)
})??;
if let Some(token_count) = cx.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})? {
let token_count = token_count.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
})
} else {
Ok(())
}
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
cx.notify();
})
})
}

View File

@@ -10,7 +10,7 @@ use crate::{
ServerId, UpdatedChannelMessage, User, UserId,
},
executor::Executor,
AppState, Error, RateLimit, RateLimiter, Result,
AppState, Config, Error, RateLimit, RateLimiter, Result,
};
use anyhow::{anyhow, bail, Context as _};
use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
))
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
complete_with_language_model(request, response, session, &app_state.config)
.await
}
}
})
.add_streaming_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
complete_with_language_model(
request,
response,
session,
app_state.config.openai_api_key.clone(),
app_state.config.google_ai_api_key.clone(),
app_state.config.anthropic_api_key.clone(),
)
let app_state = app_state.clone();
async move {
stream_complete_with_language_model(
request,
response,
session,
&app_state.config,
)
.await
}
}
})
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
count_language_model_tokens(request, response, session, &app_state.config)
.await
}
}
})
.add_request_handler({
@@ -4503,103 +4525,119 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
query: proto::QueryLanguageModel,
response: StreamingResponse<proto::QueryLanguageModel>,
request: proto::CompleteWithLanguageModel,
response: Response<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
anthropic_api_key: Option<Arc<str>>,
config: &Config,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
match proto::LanguageModelRequestKind::from_i32(query.kind) {
Some(proto::LanguageModelRequestKind::Complete) => {
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
}
Some(proto::LanguageModelRequestKind::CountTokens) => {
session
.rate_limiter
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
.await?;
}
None => Err(anyhow!("unknown request kind"))?,
}
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
match proto::LanguageModelProvider::from_i32(query.provider) {
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
let api_key = config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
anthropic::complete(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?
}
_ => return Err(anyhow!("unsupported provider"))?,
};
response.send(proto::CompleteWithLanguageModelResponse {
completion: serde_json::to_string(&result)?,
})?;
Ok(())
}
async fn stream_complete_with_language_model(
request: proto::StreamCompleteWithLanguageModel,
response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
session: Session,
config: &Config,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
session
.rate_limiter
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Anthropic) => {
let api_key = config
.anthropic_api_key
.as_ref()
.context("no Anthropic AI API key configured on the server")?;
let mut chunks = anthropic::stream_completion(
session.http_client.as_ref(),
anthropic::ANTHROPIC_API_URL,
&api_key,
serde_json::from_str(&query.request)?,
api_key,
serde_json::from_str(&request.request)?,
None,
)
.await?;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&chunk)?,
while let Some(event) = chunks.next().await {
let chunk = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&chunk)?,
})?;
}
}
Some(proto::LanguageModelProvider::OpenAi) => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
let mut chunks = open_ai::stream_completion(
let api_key = config
.openai_api_key
.as_ref()
.context("no OpenAI API key configured on the server")?;
let mut events = open_ai::stream_completion(
session.http_client.as_ref(),
open_ai::OPEN_AI_API_URL,
&api_key,
serde_json::from_str(&query.request)?,
api_key,
serde_json::from_str(&request.request)?,
None,
)
.await?;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&chunk)?,
while let Some(event) = events.next().await {
let event = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&event)?,
})?;
}
}
Some(proto::LanguageModelProvider::Google) => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
match proto::LanguageModelRequestKind::from_i32(query.kind) {
Some(proto::LanguageModelRequestKind::Complete) => {
let mut chunks = google_ai::stream_generate_content(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
serde_json::from_str(&query.request)?,
)
.await?;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&chunk)?,
})?;
}
}
Some(proto::LanguageModelRequestKind::CountTokens) => {
let tokens_response = google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
serde_json::from_str(&query.request)?,
)
.await?;
response.send(proto::QueryLanguageModelResponse {
response: serde_json::to_string(&tokens_response)?,
})?;
}
None => Err(anyhow!("unknown request kind"))?,
let api_key = config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
let mut events = google_ai::stream_generate_content(
session.http_client.as_ref(),
google_ai::API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?;
while let Some(event) = events.next().await {
let event = event?;
response.send(proto::StreamCompleteWithLanguageModelResponse {
event: serde_json::to_string(&event)?,
})?;
}
}
None => return Err(anyhow!("unknown provider"))?,
@@ -4608,11 +4646,51 @@ async fn complete_with_language_model(
Ok(())
}
struct CountTokensWithLanguageModelRateLimit;
async fn count_language_model_tokens(
request: proto::CountLanguageModelTokens,
response: Response<proto::CountLanguageModelTokens>,
session: Session,
config: &Config,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
impl RateLimit for CountTokensWithLanguageModelRateLimit {
session
.rate_limiter
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Google) => {
let api_key = config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?
}
_ => return Err(anyhow!("unsupported provider"))?,
};
response.send(proto::CountLanguageModelTokensResponse {
token_count: result.total_tokens as u32,
})?;
Ok(())
}
struct CountLanguageModelTokensRateLimit;
impl RateLimit for CountLanguageModelTokensRateLimit {
fn capacity() -> usize {
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
}
fn db_name() -> &'static str {
"count-tokens-with-language-model"
"count-language-model-tokens"
}
}

View File

@@ -26,7 +26,9 @@ anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true

View File

@@ -3,10 +3,10 @@ use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest,
LanguageModelRequest, LanguageModelTool,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{pin::Pin, sync::Arc, task::Poll};
use smol::{future::FutureExt, lock::{Semaphore, SemaphoreGuardArc}};
use std::{future, pin::Pin, sync::Arc, task::Poll};
use ui::Context;
pub fn init(cx: &mut AppContext) {
@@ -143,11 +143,11 @@ impl LanguageModelCompletionProvider {
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Option<BoxFuture<'static, Result<usize>>> {
) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = self.active_model() {
Some(model.count_tokens(request, cx))
model.count_tokens(request, cx)
} else {
None
future::ready(Err(anyhow!("no active model"))).boxed()
}
}
@@ -183,6 +183,29 @@ impl LanguageModelCompletionProvider {
Ok(completion)
})
}
pub fn use_tool<T: LanguageModelTool>(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<T>> {
if let Some(language_model) = self.active_model() {
cx.spawn(|cx| async move {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request =
language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
let response = request.await?;
Ok(serde_json::from_value(response)?)
})
} else {
Task::ready(Err(anyhow!("No active model set")))
}
}
pub fn active_model_telemetry_id(&self) -> Option<String> {
self.active_model.as_ref().map(|m| m.telemetry_id())
}
}
#[cfg(test)]

View File

@@ -16,6 +16,8 @@ pub use model::*;
pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn use_tool(
&self,
request: LanguageModelRequest,
name: String,
description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>>;
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
}
pub trait LanguageModelProvider: 'static {

View File

@@ -1,5 +1,9 @@
use anthropic::stream_completion;
use anyhow::{anyhow, Result};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
.boxed()
}
impl AnthropicModel {
fn request_completion(
&self,
request: anthropic::Request,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<anthropic::Response>> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
}
.boxed()
}
fn stream_completion(
&self,
request: anthropic::Request,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = anthropic::stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
request.await
}
.boxed()
}
}
impl LanguageModel for AnthropicModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = request.into_anthropic(self.model.id().into());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let request = self.stream_completion(request, cx);
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
}
fn use_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
let mut request = request.into_anthropic(self.model.id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
request.tools = vec![anthropic::Tool {
name: tool_name.clone(),
description: tool_description,
input_schema,
}];
let response = self.request_completion(request, cx);
async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
}
}
struct AuthenticationPrompt {

View File

@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use anyhow::{anyhow, Context as _, Result};
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
};
async move {
let request = serde_json::to_string(&request)?;
let response = client.request(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
kind: proto::LanguageModelRequestKind::CountTokens as i32,
request,
});
let response = response.await?;
let response =
serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
Ok(response.total_tokens)
let response = client
.request(proto::CountLanguageModelTokens {
provider: proto::LanguageModelProvider::Google as i32,
request,
})
.await?;
Ok(response.token_count as usize)
}
.boxed()
}
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_anthropic(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
Ok(anthropic::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_open_ai(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
Ok(open_ai::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
let request = request.into_google(model.id().into());
async move {
let request = serde_json::to_string(&request)?;
let response = client.request_stream(proto::QueryLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
kind: proto::LanguageModelRequestKind::Complete as i32,
request,
});
let chunks = response.await?;
let stream = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Google as i32,
request,
})
.await?;
Ok(google_ai::extract_text_from_events(
chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
)
.boxed())
}
@@ -311,6 +309,63 @@ impl LanguageModel for CloudLanguageModel {
}
}
}
fn use_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let mut request = request.into_anthropic(model.id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
request.tools = vec![anthropic::Tool {
name: tool_name.clone(),
description: tool_description,
input_schema,
}];
async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response = serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
})
.context("tool not used")
}
.boxed()
}
CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
}
}
}
}
struct AuthenticationPrompt {

View File

@@ -1,15 +1,17 @@
use std::sync::{Arc, Mutex};
use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
use anyhow::anyhow;
use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
use std::{
future,
sync::{Arc, Mutex},
};
use ui::WindowContext;
pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@ impl LanguageModel for FakeLanguageModel {
.insert(serde_json::to_string(&request).unwrap(), tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}

View File

@@ -9,7 +9,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
}
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
struct AuthenticationPrompt {

View File

@@ -6,7 +6,7 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{future, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
}
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
struct DownloadOllamaMessage {

View File

@@ -9,7 +9,7 @@ use gpui::{
use http_client::HttpClient;
use open_ai::stream_completion;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
}
.boxed()
}
fn use_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
pub fn count_open_ai_tokens(

View File

@@ -106,19 +106,27 @@ impl LanguageModelRequest {
messages: new_messages
.into_iter()
.filter_map(|message| {
Some(anthropic::RequestMessage {
Some(anthropic::Message {
role: match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => return None,
},
content: message.content,
content: vec![anthropic::Content::Text {
text: message.content,
}],
})
})
.collect(),
stream: true,
max_tokens: 4092,
system: system_message,
system: Some(system_message),
tools: Vec::new(),
tool_choice: None,
metadata: None,
stop_sequences: Vec::new(),
temperature: None,
top_k: None,
top_p: None,
}
}
}

View File

@@ -194,8 +194,12 @@ message Envelope {
JoinHostedProject join_hosted_project = 164;
QueryLanguageModel query_language_model = 224;
QueryLanguageModelResponse query_language_model_response = 225; // current max
CompleteWithLanguageModel complete_with_language_model = 226;
CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
CountLanguageModelTokens count_language_model_tokens = 230;
CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
@@ -267,6 +271,7 @@ message Envelope {
reserved 158 to 161;
reserved 166 to 169;
reserved 224 to 225;
}
// Messages
@@ -2050,10 +2055,31 @@ enum LanguageModelRole {
reserved 3;
}
message QueryLanguageModel {
message CompleteWithLanguageModel {
LanguageModelProvider provider = 1;
LanguageModelRequestKind kind = 2;
string request = 3;
string request = 2;
}
message CompleteWithLanguageModelResponse {
string completion = 1;
}
message StreamCompleteWithLanguageModel {
LanguageModelProvider provider = 1;
string request = 2;
}
message StreamCompleteWithLanguageModelResponse {
string event = 1;
}
message CountLanguageModelTokens {
LanguageModelProvider provider = 1;
string request = 2;
}
message CountLanguageModelTokensResponse {
uint32 token_count = 1;
}
enum LanguageModelProvider {
@@ -2062,15 +2088,6 @@ enum LanguageModelProvider {
Google = 2;
}
enum LanguageModelRequestKind {
Complete = 0;
CountTokens = 1;
}
message QueryLanguageModelResponse {
string response = 1;
}
message GetCachedEmbeddings {
string model = 1;
repeated bytes digests = 2;

View File

@@ -294,8 +294,12 @@ messages!(
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
(QueryLanguageModel, Background),
(QueryLanguageModelResponse, Background),
(CompleteWithLanguageModel, Background),
(CompleteWithLanguageModelResponse, Background),
(StreamCompleteWithLanguageModel, Background),
(StreamCompleteWithLanguageModelResponse, Background),
(CountLanguageModelTokens, Background),
(CountLanguageModelTokensResponse, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@@ -463,7 +467,12 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
(QueryLanguageModel, QueryLanguageModelResponse),
(CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
(
StreamCompleteWithLanguageModel,
StreamCompleteWithLanguageModelResponse
),
(CountLanguageModelTokens, CountLanguageModelTokensResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),