Compare commits
1 Commits
nightly
...
inline-ass
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
30919d8187 |
@@ -1,6 +1,8 @@
|
||||
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
@@ -11,12 +13,14 @@ use futures::{
|
||||
channel::mpsc,
|
||||
future::{LocalBoxFuture, Shared},
|
||||
join,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolUse, Role, TokenUsage,
|
||||
report_assistant_event,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
@@ -390,9 +394,15 @@ impl CodegenAlternative {
|
||||
|
||||
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
|
||||
let request = self.build_request(&model, user_prompt, context_task, cx)?;
|
||||
let tool_use =
|
||||
cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
|
||||
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
|
||||
let completion_events =
|
||||
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
|
||||
self.generation = self.handle_completion(
|
||||
telemetry_id,
|
||||
provider_id.to_string(),
|
||||
api_key,
|
||||
completion_events,
|
||||
cx,
|
||||
);
|
||||
} else {
|
||||
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
@@ -404,7 +414,8 @@ impl CodegenAlternative {
|
||||
})
|
||||
.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
self.generation =
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -586,6 +597,21 @@ impl CodegenAlternative {
|
||||
}))
|
||||
}
|
||||
|
||||
// stream: impl Future<Output = Result<InlineAssistantStream>>
|
||||
// impl Stream for InlineAssistantStream {
|
||||
// type Output = InlineAssistantChunk
|
||||
// }
|
||||
//
|
||||
// enum InlineAssistantChunk {
|
||||
// rewrite_text(String)
|
||||
// Error(Err)
|
||||
// }
|
||||
// explanation_text(String)
|
||||
//
|
||||
//
|
||||
//
|
||||
// handle_completion_stream
|
||||
|
||||
pub fn handle_stream(
|
||||
&mut self,
|
||||
model_telemetry_id: String,
|
||||
@@ -593,7 +619,7 @@ impl CodegenAlternative {
|
||||
model_api_key: Option<String>,
|
||||
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> Task<()> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Make a new snapshot and re-resolve anchor in case the document was modified.
|
||||
@@ -647,7 +673,8 @@ impl CodegenAlternative {
|
||||
let completion = Arc::new(Mutex::new(String::new()));
|
||||
let completion_clone = completion.clone();
|
||||
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
cx.notify();
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
let stream = stream.await;
|
||||
|
||||
let token_usage = stream
|
||||
@@ -673,6 +700,42 @@ impl CodegenAlternative {
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
// impl Stream<Output = Result<String>>;
|
||||
|
||||
// struct StreamingDiffLoop {
|
||||
// diff: StreamingDiff,
|
||||
// line_diff: LineDiff,
|
||||
// new_text: String,
|
||||
// base_indent: Option<usize>,
|
||||
// line_indent: Option<usize>,
|
||||
// first_line: bool,
|
||||
// }
|
||||
|
||||
// impl StreamingDiffLoop {
|
||||
// fn new(selected_text: &str) -> Self {
|
||||
// Self {
|
||||
// diff: StreamingDiff::new(selected_text.to_string()),
|
||||
// line_diff: LineDiff::default(),
|
||||
// new_text: String::new(),
|
||||
// base_indent: None,
|
||||
// line_indent: None,
|
||||
// first_line: true,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// let diff_loop = StreamingDiffLoop::new(selected_text.to_string());
|
||||
|
||||
// while let Some(chunk) = chunks.next().await {
|
||||
// if response_latency.is_none() {
|
||||
// response_latency = Some(request_start.elapsed());
|
||||
// }
|
||||
// let chunk = chunk?;
|
||||
// completion_clone.lock().push_str(&chunk);
|
||||
|
||||
// diff_loop.push(chunk, suggested_line_indent, selection_start, selected_text, diff_tx);
|
||||
// }
|
||||
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -864,8 +927,7 @@ impl CodegenAlternative {
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
|
||||
pub fn stop(&mut self, cx: &mut Context<Self>) {
|
||||
@@ -1040,21 +1102,29 @@ impl CodegenAlternative {
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use(
|
||||
fn handle_completion(
|
||||
&mut self,
|
||||
_telemetry_id: String,
|
||||
_provider_id: String,
|
||||
_api_key: Option<String>,
|
||||
tool_use: impl 'static
|
||||
+ Future<
|
||||
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
|
||||
telemetry_id: String,
|
||||
provider_id: String,
|
||||
api_key: Option<String>,
|
||||
completion_stream: Task<
|
||||
Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> Task<()> {
|
||||
self.diff = Diff::default();
|
||||
self.status = CodegenStatus::Pending;
|
||||
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
cx.notify();
|
||||
// Leaving this in generation so that STOP equivalent events are respected even
|
||||
// while we're still pre-processing the completion event
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
|
||||
let _ = codegen.update(cx, |this, cx| {
|
||||
this.status = status;
|
||||
@@ -1063,76 +1133,130 @@ impl CodegenAlternative {
|
||||
});
|
||||
};
|
||||
|
||||
let tool_use = tool_use.await;
|
||||
let mut completion_events = match completion_stream.await {
|
||||
Ok(events) => events,
|
||||
Err(err) => {
|
||||
finish_with_status(CodegenStatus::Error(err.into()), cx);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match tool_use {
|
||||
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
|
||||
// Parse the input JSON into RewriteSectionInput
|
||||
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
|
||||
Ok(input) => {
|
||||
// Store the description if non-empty
|
||||
let description = if !input.description.trim().is_empty() {
|
||||
Some(input.description.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let chars_read_so_far = Arc::new(Mutex::new(0usize));
|
||||
let tool_to_text = move |tool_use: LanguageModelToolUse| -> String {
|
||||
let mut chars_read_so_far = chars_read_so_far.lock();
|
||||
dbg!(&tool_use);
|
||||
let input: RewriteSectionInput =
|
||||
serde_json::from_value(tool_use.input.clone()).unwrap();
|
||||
let value = input.replacement_text[*chars_read_so_far..].to_string();
|
||||
*chars_read_so_far = value.len();
|
||||
value
|
||||
};
|
||||
|
||||
// Apply the replacement text to the buffer and compute diff
|
||||
let batch_diff_task = codegen
|
||||
.update(cx, |this, cx| {
|
||||
this.model_explanation = description.map(Into::into);
|
||||
let range = this.range.clone();
|
||||
this.apply_edits(
|
||||
std::iter::once((range, input.replacement_text)),
|
||||
cx,
|
||||
);
|
||||
this.reapply_batch_diff(cx)
|
||||
})
|
||||
.ok();
|
||||
let mut message_id = None;
|
||||
let mut first_text = None;
|
||||
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
|
||||
let total_text = Arc::new(Mutex::new(String::new()));
|
||||
|
||||
// Wait for the diff computation to complete
|
||||
if let Some(diff_task) = batch_diff_task {
|
||||
diff_task.await;
|
||||
}
|
||||
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
loop {
|
||||
if let Some(first_event) = completion_events.next().await {
|
||||
dbg!(&first_event);
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
dbg!("AAA 0");
|
||||
message_id = Some(id);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if tool_use.name.as_ref() == "rewrite_section" =>
|
||||
{
|
||||
dbg!("AAA 1");
|
||||
first_text = Some(tool_to_text(tool_use));
|
||||
break;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
}
|
||||
Ok(e) => {
|
||||
log::warn!("Unexpected event: {:?}", e);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
|
||||
// Handle failure message tool use
|
||||
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
|
||||
Ok(input) => {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
// Store the failure message as the tool description
|
||||
this.model_explanation = Some(input.message.into());
|
||||
});
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_tool_use) => {
|
||||
// Unexpected tool.
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
cx.notify();
|
||||
|
||||
let text = total_text.lock().clone();
|
||||
dbg!(text);
|
||||
|
||||
let Some(first_text) = first_text else {
|
||||
finish_with_status(
|
||||
CodegenStatus::Error(anyhow!("Failed to start????").into()),
|
||||
cx,
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
let move_last_token_usage = last_token_usage.clone();
|
||||
|
||||
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
|
||||
completion_events.filter_map(move |e| {
|
||||
let tool_to_text = tool_to_text.clone();
|
||||
let last_token_usage = move_last_token_usage.clone();
|
||||
let total_text = total_text.clone();
|
||||
async move {
|
||||
match e {
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if tool_use.name.as_ref() == "rewrite_section" =>
|
||||
{
|
||||
Some(Ok(tool_to_text(tool_use)))
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
None
|
||||
}
|
||||
e => {
|
||||
println!("UNEXPECTED EVENT {:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
));
|
||||
|
||||
let language_model_text_stream = LanguageModelTextStream {
|
||||
message_id: message_id,
|
||||
stream: text_stream,
|
||||
last_token_usage,
|
||||
};
|
||||
|
||||
let Some(task) = codegen
|
||||
.update(cx, move |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
telemetry_id,
|
||||
provider_id,
|
||||
api_key,
|
||||
async { Ok(language_model_text_stream) },
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
task.await;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1659,7 +1783,7 @@ mod tests {
|
||||
) -> mpsc::UnboundedSender<String> {
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
codegen.generation = codegen.handle_stream(
|
||||
String::new(),
|
||||
String::new(),
|
||||
None,
|
||||
|
||||
@@ -16,4 +16,8 @@ pub struct InlineAssistantV2FeatureFlag;
|
||||
|
||||
impl FeatureFlag for InlineAssistantV2FeatureFlag {
|
||||
const NAME: &'static str = "inline-assistant-v2";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user