Compare commits

...

1 Commit

Author SHA1 Message Date
Michael Benfield
30919d8187 tool use conversion to streaming in progress
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-12-09 16:11:37 -08:00
2 changed files with 208 additions and 80 deletions

View File

@@ -1,6 +1,8 @@
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use anyhow::anyhow;
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
@@ -11,12 +13,14 @@ use futures::{
channel::mpsc,
future::{LocalBoxFuture, Shared},
join,
stream::BoxStream,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolUse, Role, TokenUsage,
report_assistant_event,
};
use multi_buffer::MultiBufferRow;
@@ -390,9 +394,15 @@ impl CodegenAlternative {
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
let request = self.build_request(&model, user_prompt, context_task, cx)?;
let tool_use =
cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
let completion_events =
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
self.generation = self.handle_completion(
telemetry_id,
provider_id.to_string(),
api_key,
completion_events,
cx,
);
} else {
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
if user_prompt.trim().to_lowercase() == "delete" {
@@ -404,7 +414,8 @@ impl CodegenAlternative {
})
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
self.generation =
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
}
Ok(())
@@ -586,6 +597,21 @@ impl CodegenAlternative {
}))
}
// stream: impl Future<Output = Result<InlineAssistantStream>>
// impl Stream for InlineAssistantStream {
// type Output = InlineAssistantChunk
// }
//
// enum InlineAssistantChunk {
// rewrite_text(String)
// Error(Err)
// }
// explanation_text(String)
//
//
//
// handle_completion_stream
pub fn handle_stream(
&mut self,
model_telemetry_id: String,
@@ -593,7 +619,7 @@ impl CodegenAlternative {
model_api_key: Option<String>,
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let start_time = Instant::now();
// Make a new snapshot and re-resolve anchor in case the document was modified.
@@ -647,7 +673,8 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
self.generation = cx.spawn(async move |codegen, cx| {
cx.notify();
cx.spawn(async move |codegen, cx| {
let stream = stream.await;
let token_usage = stream
@@ -673,6 +700,42 @@ impl CodegenAlternative {
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
// impl Stream<Output = Result<String>>;
// struct StreamingDiffLoop {
// diff: StreamingDiff,
// line_diff: LineDiff,
// new_text: String,
// base_indent: Option<usize>,
// line_indent: Option<usize>,
// first_line: bool,
// }
// impl StreamingDiffLoop {
// fn new(selected_text: &str) -> Self {
// Self {
// diff: StreamingDiff::new(selected_text.to_string()),
// line_diff: LineDiff::default(),
// new_text: String::new(),
// base_indent: None,
// line_indent: None,
// first_line: true,
// }
// }
// }
// let diff_loop = StreamingDiffLoop::new(selected_text.to_string());
// while let Some(chunk) = chunks.next().await {
// if response_latency.is_none() {
// response_latency = Some(request_start.elapsed());
// }
// let chunk = chunk?;
// completion_clone.lock().push_str(&chunk);
// diff_loop.push(chunk, suggested_line_indent, selection_start, selected_text, diff_tx);
// }
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -864,8 +927,7 @@ impl CodegenAlternative {
cx.notify();
})
.ok();
});
cx.notify();
})
}
pub fn stop(&mut self, cx: &mut Context<Self>) {
@@ -1040,21 +1102,29 @@ impl CodegenAlternative {
})
}
fn handle_tool_use(
fn handle_completion(
&mut self,
_telemetry_id: String,
_provider_id: String,
_api_key: Option<String>,
tool_use: impl 'static
+ Future<
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
telemetry_id: String,
provider_id: String,
api_key: Option<String>,
completion_stream: Task<
Result<
BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
>,
cx: &mut Context<Self>,
) {
) -> Task<()> {
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
self.generation = cx.spawn(async move |codegen, cx| {
cx.notify();
// Leaving this in generation so that STOP equivalent events are respected even
// while we're still pre-processing the completion event
cx.spawn(async move |codegen, cx| {
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
let _ = codegen.update(cx, |this, cx| {
this.status = status;
@@ -1063,76 +1133,130 @@ impl CodegenAlternative {
});
};
let tool_use = tool_use.await;
let mut completion_events = match completion_stream.await {
Ok(events) => events,
Err(err) => {
finish_with_status(CodegenStatus::Error(err.into()), cx);
return;
}
};
match tool_use {
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
// Parse the input JSON into RewriteSectionInput
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
Ok(input) => {
// Store the description if non-empty
let description = if !input.description.trim().is_empty() {
Some(input.description.clone())
} else {
None
};
let chars_read_so_far = Arc::new(Mutex::new(0usize));
let tool_to_text = move |tool_use: LanguageModelToolUse| -> String {
let mut chars_read_so_far = chars_read_so_far.lock();
dbg!(&tool_use);
let input: RewriteSectionInput =
serde_json::from_value(tool_use.input.clone()).unwrap();
let value = input.replacement_text[*chars_read_so_far..].to_string();
*chars_read_so_far = value.len();
value
};
// Apply the replacement text to the buffer and compute diff
let batch_diff_task = codegen
.update(cx, |this, cx| {
this.model_explanation = description.map(Into::into);
let range = this.range.clone();
this.apply_edits(
std::iter::once((range, input.replacement_text)),
cx,
);
this.reapply_batch_diff(cx)
})
.ok();
let mut message_id = None;
let mut first_text = None;
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
let total_text = Arc::new(Mutex::new(String::new()));
// Wait for the diff computation to complete
if let Some(diff_task) = batch_diff_task {
diff_task.await;
}
finish_with_status(CodegenStatus::Done, cx);
return;
loop {
if let Some(first_event) = completion_events.next().await {
dbg!(&first_event);
match first_event {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
dbg!("AAA 0");
message_id = Some(id);
}
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if tool_use.name.as_ref() == "rewrite_section" =>
{
dbg!("AAA 1");
first_text = Some(tool_to_text(tool_use));
break;
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
}
Ok(LanguageModelCompletionEvent::Text(text)) => {
let mut lock = total_text.lock();
lock.push_str(&text);
}
Ok(e) => {
log::warn!("Unexpected event: {:?}", e);
break;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
break;
}
}
}
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
// Handle failure message tool use
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
Ok(input) => {
let _ = codegen.update(cx, |this, _cx| {
// Store the failure message as the tool description
this.model_explanation = Some(input.message.into());
});
finish_with_status(CodegenStatus::Done, cx);
return;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
}
}
}
Ok(_tool_use) => {
// Unexpected tool.
finish_with_status(CodegenStatus::Done, cx);
return;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
}
}
});
cx.notify();
let text = total_text.lock().clone();
dbg!(text);
let Some(first_text) = first_text else {
finish_with_status(
CodegenStatus::Error(anyhow!("Failed to start????").into()),
cx,
);
return;
};
let move_last_token_usage = last_token_usage.clone();
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
completion_events.filter_map(move |e| {
let tool_to_text = tool_to_text.clone();
let last_token_usage = move_last_token_usage.clone();
let total_text = total_text.clone();
async move {
match e {
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if tool_use.name.as_ref() == "rewrite_section" =>
{
Some(Ok(tool_to_text(tool_use)))
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
}
Ok(LanguageModelCompletionEvent::Text(text)) => {
let mut lock = total_text.lock();
lock.push_str(&text);
None
}
e => {
println!("UNEXPECTED EVENT {:?}", e);
None
}
}
}
}),
));
let language_model_text_stream = LanguageModelTextStream {
message_id: message_id,
stream: text_stream,
last_token_usage,
};
let Some(task) = codegen
.update(cx, move |codegen, cx| {
codegen.handle_stream(
telemetry_id,
provider_id,
api_key,
async { Ok(language_model_text_stream) },
cx,
)
})
.ok()
else {
return;
};
task.await;
})
}
}
@@ -1659,7 +1783,7 @@ mod tests {
) -> mpsc::UnboundedSender<String> {
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
codegen.generation = codegen.handle_stream(
String::new(),
String::new(),
None,

View File

@@ -16,4 +16,8 @@ pub struct InlineAssistantV2FeatureFlag;
// Feature-flag registration for the inline assistant v2 experience.
// NOTE(review): this is the new side of a diff hunk; original indentation was
// stripped by the page copy — confirm formatting against the real file.
impl FeatureFlag for InlineAssistantV2FeatureFlag {
// Flag key looked up by the feature-flag service.
const NAME: &'static str = "inline-assistant-v2";
// Returning true turns the flag on unconditionally for staff accounts,
// regardless of the remote flag value.
fn enabled_for_staff() -> bool {
true
}
}