Compare commits
6 Commits
no-install
...
sumtree-v1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3870eeff32 | ||
|
|
74a6c78fff | ||
|
|
2be18a200f | ||
|
|
4e2907e296 | ||
|
|
b43153a99f | ||
|
|
57ea58c83e |
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -244,7 +244,6 @@ dependencies = [
|
||||
"terminal",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"tree-sitter-rust",
|
||||
"ui",
|
||||
"unindent",
|
||||
|
||||
@@ -61,7 +61,6 @@ sqlez.workspace = true
|
||||
task.workspace = true
|
||||
telemetry.workspace = true
|
||||
terminal.workspace = true
|
||||
thiserror.workspace = true
|
||||
text.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -499,16 +499,6 @@ pub struct ToolCallAuthorization {
|
||||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum CompletionError {
|
||||
#[error("max tokens")]
|
||||
MaxTokens,
|
||||
#[error("refusal")]
|
||||
Refusal,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
id: acp::SessionId,
|
||||
prompt_id: PromptId,
|
||||
@@ -1087,62 +1077,101 @@ impl Thread {
|
||||
_task: cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let mut update_title = None;
|
||||
let turn_result: Result<()> = async {
|
||||
let mut intent = CompletionIntent::UserPrompt;
|
||||
let turn_result: Result<StopReason> = async {
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})??;
|
||||
|
||||
log::info!("Calling model.stream_completion");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut refused = false;
|
||||
let mut reached_max_tokens = false;
|
||||
let mut tool_uses = Self::stream_completion_with_retries(
|
||||
this.clone(),
|
||||
model.clone(),
|
||||
request,
|
||||
&event_stream,
|
||||
&mut tool_use_limit_reached,
|
||||
&mut refused,
|
||||
&mut reached_max_tokens,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if refused {
|
||||
return Ok(StopReason::Refusal);
|
||||
} else if reached_max_tokens {
|
||||
return Ok(StopReason::MaxTokens);
|
||||
}
|
||||
|
||||
let end_turn = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})?;
|
||||
}
|
||||
|
||||
let mut end_turn = true;
|
||||
this.update(cx, |this, cx| {
|
||||
// Generate title if needed.
|
||||
if this.title.is_none() && update_title.is_none() {
|
||||
update_title = Some(this.update_title(&event_stream, cx));
|
||||
}
|
||||
|
||||
// End the turn if the model didn't use tools.
|
||||
let message = this.pending_message.as_ref();
|
||||
end_turn =
|
||||
message.map_or(true, |message| message.tool_results.is_empty());
|
||||
this.flush_pending_message(cx);
|
||||
})?;
|
||||
|
||||
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
|
||||
if tool_use_limit_reached {
|
||||
log::info!("Tool use limit reached, completing turn");
|
||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||
return Err(language_model::ToolUseLimitReachedError.into());
|
||||
} else if end_turn {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
return Ok(StopReason::EndTurn);
|
||||
} else {
|
||||
intent = CompletionIntent::ToolResults;
|
||||
this.update(cx, |this, cx| this.flush_pending_message(cx))?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
|
||||
|
||||
if let Some(update_title) = update_title {
|
||||
update_title.await.context("update title failed").log_err();
|
||||
}
|
||||
|
||||
match turn_result {
|
||||
Ok(()) => {
|
||||
log::info!("Turn execution completed");
|
||||
event_stream.send_stop(acp::StopReason::EndTurn);
|
||||
Ok(reason) => {
|
||||
log::info!("Turn execution completed: {:?}", reason);
|
||||
|
||||
if let Some(update_title) = update_title {
|
||||
update_title.await.context("update title failed").log_err();
|
||||
}
|
||||
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
match error.downcast::<CompletionError>() {
|
||||
Ok(CompletionError::Refusal) => {
|
||||
event_stream.send_stop(acp::StopReason::Refusal);
|
||||
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
|
||||
}
|
||||
Ok(CompletionError::MaxTokens) => {
|
||||
event_stream.send_stop(acp::StopReason::MaxTokens);
|
||||
}
|
||||
Ok(CompletionError::Other(error)) | Err(error) => {
|
||||
event_stream.send_error(error);
|
||||
}
|
||||
}
|
||||
event_stream.send_error(error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1152,17 +1181,17 @@ impl Thread {
|
||||
Ok(events_rx)
|
||||
}
|
||||
|
||||
async fn stream_completion(
|
||||
this: &WeakEntity<Self>,
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
completion_intent: CompletionIntent,
|
||||
async fn stream_completion_with_retries(
|
||||
this: WeakEntity<Self>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
request: LanguageModelRequest,
|
||||
event_stream: &ThreadEventStream,
|
||||
tool_use_limit_reached: &mut bool,
|
||||
refusal: &mut bool,
|
||||
max_tokens_reached: &mut bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
|
||||
log::debug!("Stream completion started successfully");
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})??;
|
||||
|
||||
let mut attempt = None;
|
||||
'retry: loop {
|
||||
@@ -1175,33 +1204,68 @@ impl Thread {
|
||||
attempt
|
||||
);
|
||||
|
||||
log::info!(
|
||||
"Calling model.stream_completion, attempt {}",
|
||||
attempt.unwrap_or(0)
|
||||
);
|
||||
let mut events = model
|
||||
.stream_completion(request.clone(), cx)
|
||||
.await
|
||||
.map_err(|error| anyhow!(error))?;
|
||||
let mut tool_results = FuturesUnordered::new();
|
||||
|
||||
let mut events = model.stream_completion(request.clone(), cx).await?;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
)) => {
|
||||
*tool_use_limit_reached = true;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
||||
)) => {
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_model_request_usage(amount, limit, cx)
|
||||
})?;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
|
||||
telemetry::event!(
|
||||
"Agent Thread Completion Usage Updated",
|
||||
thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
|
||||
prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
attempt,
|
||||
input_tokens = usage.input_tokens,
|
||||
output_tokens = usage.output_tokens,
|
||||
cache_creation_input_tokens = usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens = usage.cache_read_input_tokens,
|
||||
);
|
||||
|
||||
this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
|
||||
*refusal = true;
|
||||
return Ok(FuturesUnordered::default());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
|
||||
*max_tokens_reached = true;
|
||||
return Ok(FuturesUnordered::default());
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(
|
||||
StopReason::ToolUse | StopReason::EndTurn,
|
||||
)) => break,
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
tool_results.extend(this.update(cx, |this, cx| {
|
||||
this.handle_streamed_completion_event(event, event_stream, cx)
|
||||
})??);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
event_stream,
|
||||
cx,
|
||||
));
|
||||
})?;
|
||||
}
|
||||
Err(error) => {
|
||||
let completion_mode =
|
||||
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
|
||||
if completion_mode == CompletionMode::Normal {
|
||||
return Err(anyhow!(error))?;
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let Some(strategy) = Self::retry_strategy_for(&error) else {
|
||||
return Err(anyhow!(error))?;
|
||||
return Err(error.into());
|
||||
};
|
||||
|
||||
let max_attempts = match &strategy {
|
||||
@@ -1215,7 +1279,7 @@ impl Thread {
|
||||
|
||||
let attempt = *attempt;
|
||||
if attempt > max_attempts {
|
||||
return Err(anyhow!(error))?;
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let delay = match &strategy {
|
||||
@@ -1242,29 +1306,7 @@ impl Thread {
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(tool_result) = tool_results.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
return Ok(tool_uses);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1286,14 +1328,14 @@ impl Thread {
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop will
|
||||
/// send back to the model when it resolves.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
event_stream: &ThreadEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<Option<Task<LanguageModelToolResult>>> {
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
log::trace!("Handling streamed completion event: {:?}", event);
|
||||
use LanguageModelCompletionEvent::*;
|
||||
|
||||
@@ -1308,7 +1350,7 @@ impl Thread {
|
||||
}
|
||||
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
|
||||
ToolUse(tool_use) => {
|
||||
return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
|
||||
return self.handle_tool_use_event(tool_use, event_stream, cx);
|
||||
}
|
||||
ToolUseJsonParseError {
|
||||
id,
|
||||
@@ -1316,46 +1358,18 @@ impl Thread {
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
} => {
|
||||
return Ok(Some(Task::ready(
|
||||
self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
),
|
||||
return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
)));
|
||||
}
|
||||
UsageUpdate(usage) => {
|
||||
telemetry::event!(
|
||||
"Agent Thread Completion Usage Updated",
|
||||
thread_id = self.id.to_string(),
|
||||
prompt_id = self.prompt_id.to_string(),
|
||||
model = self.model.as_ref().map(|m| m.telemetry_id()),
|
||||
model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
|
||||
input_tokens = usage.input_tokens,
|
||||
output_tokens = usage.output_tokens,
|
||||
cache_creation_input_tokens = usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens = usage.cache_read_input_tokens,
|
||||
);
|
||||
self.update_token_usage(usage, cx);
|
||||
}
|
||||
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
|
||||
self.update_model_request_usage(amount, limit, cx);
|
||||
}
|
||||
StatusUpdate(
|
||||
CompletionRequestStatus::Started
|
||||
| CompletionRequestStatus::Queued { .. }
|
||||
| CompletionRequestStatus::Failed { .. },
|
||||
) => {}
|
||||
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
|
||||
self.tool_use_limit_reached = true;
|
||||
}
|
||||
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
|
||||
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
|
||||
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
|
||||
StatusUpdate(_) => {}
|
||||
UsageUpdate(_) | Stop(_) => unreachable!(),
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_text_event(
|
||||
@@ -2211,8 +2225,25 @@ impl ThreadEventStream {
|
||||
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: acp::StopReason) {
|
||||
self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::MaxTokens => {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::Refusal => {
|
||||
self.0
|
||||
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::ToolUse => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_canceled(&self) {
|
||||
|
||||
@@ -2637,7 +2637,7 @@ impl AcpThreadView {
|
||||
}
|
||||
|
||||
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
|
||||
let container = v_flex()
|
||||
let mut container = v_flex()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_error_agent_logo())
|
||||
@@ -2656,9 +2656,7 @@ impl AcpThreadView {
|
||||
),
|
||||
);
|
||||
|
||||
let button = if !self.project.read(cx).is_local() {
|
||||
None
|
||||
} else if let LoadError::Unsupported {
|
||||
if let LoadError::Unsupported {
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
..
|
||||
@@ -2666,7 +2664,7 @@ impl AcpThreadView {
|
||||
{
|
||||
let upgrade_message = upgrade_message.clone();
|
||||
let upgrade_command = upgrade_command.clone();
|
||||
Some(
|
||||
container = container.child(
|
||||
Button::new("upgrade", upgrade_message)
|
||||
.tooltip(Tooltip::text(upgrade_command.clone()))
|
||||
.on_click(cx.listener(move |this, _, window, cx| {
|
||||
@@ -2709,7 +2707,7 @@ impl AcpThreadView {
|
||||
})
|
||||
.detach()
|
||||
})),
|
||||
)
|
||||
);
|
||||
} else if let LoadError::NotInstalled {
|
||||
install_message,
|
||||
install_command,
|
||||
@@ -2718,7 +2716,7 @@ impl AcpThreadView {
|
||||
{
|
||||
let install_message = install_message.clone();
|
||||
let install_command = install_command.clone();
|
||||
Some(
|
||||
container = container.child(
|
||||
Button::new("install", install_message)
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.size(ButtonSize::Medium)
|
||||
@@ -2763,12 +2761,10 @@ impl AcpThreadView {
|
||||
})
|
||||
.detach()
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
);
|
||||
}
|
||||
|
||||
container.children(button).into_any()
|
||||
container.into_any()
|
||||
}
|
||||
|
||||
fn render_activity_bar(
|
||||
|
||||
@@ -154,7 +154,7 @@ impl BufferDiffSnapshot {
|
||||
BufferDiffSnapshot {
|
||||
inner: BufferDiffInner {
|
||||
base_text: language::Buffer::build_empty_snapshot(cx),
|
||||
hunks: SumTree::new(buffer),
|
||||
hunks: SumTree::new(),
|
||||
pending_hunks: SumTree::new(buffer),
|
||||
base_text_exists: false,
|
||||
},
|
||||
|
||||
@@ -96,7 +96,7 @@ impl<T: LocalLanguageToolchainStore> LanguageToolchainStore for T {
|
||||
}
|
||||
|
||||
type DefaultIndex = usize;
|
||||
#[derive(Default, Clone, Debug)]
|
||||
#[derive(Default, Clone)]
|
||||
pub struct ToolchainList {
|
||||
pub toolchains: Vec<Toolchain>,
|
||||
pub default: Option<DefaultIndex>,
|
||||
|
||||
@@ -4643,6 +4643,7 @@ impl LspStore {
|
||||
Some((file, language, raw_buffer.remote_id()))
|
||||
})
|
||||
.sorted_by_key(|(file, _, _)| Reverse(file.worktree.read(cx).is_visible()));
|
||||
|
||||
for (file, language, buffer_id) in buffers {
|
||||
let worktree_id = file.worktree_id(cx);
|
||||
let Some(worktree) = local
|
||||
@@ -4684,6 +4685,7 @@ impl LspStore {
|
||||
cx,
|
||||
)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for node in nodes {
|
||||
let server_id = node.server_id_or_init(|disposition| {
|
||||
let path = &disposition.path;
|
||||
|
||||
@@ -181,7 +181,6 @@ impl LanguageServerTree {
|
||||
&root_path.path,
|
||||
language_name.clone(),
|
||||
);
|
||||
|
||||
(
|
||||
Arc::new(InnerTreeNode::new(
|
||||
adapter.name(),
|
||||
@@ -409,7 +408,6 @@ impl ServerTreeRebase {
|
||||
if live_node.id.get().is_some() {
|
||||
return Some(node);
|
||||
}
|
||||
|
||||
let disposition = &live_node.disposition;
|
||||
let Some((existing_node, _)) = self
|
||||
.old_contents
|
||||
|
||||
@@ -4,7 +4,6 @@ use crate::{
|
||||
Event, git_store::StatusEntry, task_inventory::TaskContexts, task_store::TaskSettingsLocation,
|
||||
*,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use buffer_diff::{
|
||||
BufferDiffEvent, CALCULATE_DIFF_TASK, DiffHunkSecondaryStatus, DiffHunkStatus,
|
||||
DiffHunkStatusKind, assert_hunks,
|
||||
@@ -22,8 +21,7 @@ use http_client::Url;
|
||||
use itertools::Itertools;
|
||||
use language::{
|
||||
Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSourceKind, DiskState, FakeLspAdapter,
|
||||
LanguageConfig, LanguageMatcher, LanguageName, LineEnding, ManifestName, ManifestProvider,
|
||||
ManifestQuery, OffsetRangeExt, Point, ToPoint, ToolchainLister,
|
||||
LanguageConfig, LanguageMatcher, LanguageName, LineEnding, OffsetRangeExt, Point, ToPoint,
|
||||
language_settings::{AllLanguageSettings, LanguageSettingsContent, language_settings},
|
||||
tree_sitter_rust, tree_sitter_typescript,
|
||||
};
|
||||
@@ -598,203 +596,6 @@ async fn test_fallback_to_single_worktree_tasks(cx: &mut gpui::TestAppContext) {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_running_multiple_instances_of_a_single_server_in_one_worktree(
|
||||
cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
pub(crate) struct PyprojectTomlManifestProvider;
|
||||
|
||||
impl ManifestProvider for PyprojectTomlManifestProvider {
|
||||
fn name(&self) -> ManifestName {
|
||||
SharedString::new_static("pyproject.toml").into()
|
||||
}
|
||||
|
||||
fn search(
|
||||
&self,
|
||||
ManifestQuery {
|
||||
path,
|
||||
depth,
|
||||
delegate,
|
||||
}: ManifestQuery,
|
||||
) -> Option<Arc<Path>> {
|
||||
for path in path.ancestors().take(depth) {
|
||||
let p = path.join("pyproject.toml");
|
||||
if delegate.exists(&p, Some(false)) {
|
||||
return Some(path.into());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/the-root"),
|
||||
json!({
|
||||
".zed": {
|
||||
"settings.json": r#"
|
||||
{
|
||||
"languages": {
|
||||
"Python": {
|
||||
"language_servers": ["ty"]
|
||||
}
|
||||
}
|
||||
}"#
|
||||
},
|
||||
"project-a": {
|
||||
".venv": {},
|
||||
"file.py": "",
|
||||
"pyproject.toml": ""
|
||||
},
|
||||
"project-b": {
|
||||
".venv": {},
|
||||
"source_file.py":"",
|
||||
"another_file.py": "",
|
||||
"pyproject.toml": ""
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
cx.update(|cx| {
|
||||
ManifestProvidersStore::global(cx).register(Arc::new(PyprojectTomlManifestProvider))
|
||||
});
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/the-root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
let _fake_python_server = language_registry.register_fake_lsp(
|
||||
"Python",
|
||||
FakeLspAdapter {
|
||||
name: "ty",
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
language_registry.add(python_lang(fs.clone()));
|
||||
let (first_buffer, _handle) = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/the-root/project-a/file.py"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx.executor().run_until_parked();
|
||||
let servers = project.update(cx, |project, cx| {
|
||||
project.lsp_store.update(cx, |this, cx| {
|
||||
first_buffer.update(cx, |buffer, cx| {
|
||||
this.language_servers_for_local_buffer(buffer, cx)
|
||||
.map(|(adapter, server)| (adapter.clone(), server.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
assert_eq!(servers.len(), 1);
|
||||
let (adapter, server) = servers.into_iter().next().unwrap();
|
||||
assert_eq!(adapter.name(), LanguageServerName::new_static("ty"));
|
||||
assert_eq!(server.server_id(), LanguageServerId(0));
|
||||
// `workspace_folders` are set to the rooting point.
|
||||
assert_eq!(
|
||||
server.workspace_folders(),
|
||||
BTreeSet::from_iter(
|
||||
[Url::from_file_path(path!("/the-root/project-a")).unwrap()].into_iter()
|
||||
)
|
||||
);
|
||||
|
||||
let (second_project_buffer, _other_handle) = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer_with_lsp(path!("/the-root/project-b/source_file.py"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx.executor().run_until_parked();
|
||||
let servers = project.update(cx, |project, cx| {
|
||||
project.lsp_store.update(cx, |this, cx| {
|
||||
second_project_buffer.update(cx, |buffer, cx| {
|
||||
this.language_servers_for_local_buffer(buffer, cx)
|
||||
.map(|(adapter, server)| (adapter.clone(), server.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
assert_eq!(servers.len(), 1);
|
||||
let (adapter, server) = servers.into_iter().next().unwrap();
|
||||
assert_eq!(adapter.name(), LanguageServerName::new_static("ty"));
|
||||
// We're not using venvs at all here, so both folders should fall under the same root.
|
||||
assert_eq!(server.server_id(), LanguageServerId(0));
|
||||
// Now, let's select a different toolchain for one of subprojects.
|
||||
let (available_toolchains_for_b, root_path) = project
|
||||
.update(cx, |this, cx| {
|
||||
let worktree_id = this.worktrees(cx).next().unwrap().read(cx).id();
|
||||
this.available_toolchains(
|
||||
ProjectPath {
|
||||
worktree_id,
|
||||
path: Arc::from("project-b/source_file.py".as_ref()),
|
||||
},
|
||||
LanguageName::new("Python"),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.expect("A toolchain to be discovered");
|
||||
assert_eq!(root_path.as_ref(), Path::new("project-b"));
|
||||
assert_eq!(available_toolchains_for_b.toolchains().len(), 1);
|
||||
let currently_active_toolchain = project
|
||||
.update(cx, |this, cx| {
|
||||
let worktree_id = this.worktrees(cx).next().unwrap().read(cx).id();
|
||||
this.active_toolchain(
|
||||
ProjectPath {
|
||||
worktree_id,
|
||||
path: Arc::from("project-b/source_file.py".as_ref()),
|
||||
},
|
||||
LanguageName::new("Python"),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(currently_active_toolchain.is_none());
|
||||
let _ = project
|
||||
.update(cx, |this, cx| {
|
||||
let worktree_id = this.worktrees(cx).next().unwrap().read(cx).id();
|
||||
this.activate_toolchain(
|
||||
ProjectPath {
|
||||
worktree_id,
|
||||
path: root_path,
|
||||
},
|
||||
available_toolchains_for_b
|
||||
.toolchains
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
let servers = project.update(cx, |project, cx| {
|
||||
project.lsp_store.update(cx, |this, cx| {
|
||||
second_project_buffer.update(cx, |buffer, cx| {
|
||||
this.language_servers_for_local_buffer(buffer, cx)
|
||||
.map(|(adapter, server)| (adapter.clone(), server.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
assert_eq!(servers.len(), 1);
|
||||
let (adapter, server) = servers.into_iter().next().unwrap();
|
||||
assert_eq!(adapter.name(), LanguageServerName::new_static("ty"));
|
||||
// There's a new language server in town.
|
||||
assert_eq!(server.server_id(), LanguageServerId(1));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_managing_language_servers(cx: &mut gpui::TestAppContext) {
|
||||
init_test(cx);
|
||||
@@ -9181,65 +8982,6 @@ fn rust_lang() -> Arc<Language> {
|
||||
))
|
||||
}
|
||||
|
||||
fn python_lang(fs: Arc<FakeFs>) -> Arc<Language> {
|
||||
struct PythonMootToolchainLister(Arc<FakeFs>);
|
||||
#[async_trait]
|
||||
impl ToolchainLister for PythonMootToolchainLister {
|
||||
async fn list(
|
||||
&self,
|
||||
worktree_root: PathBuf,
|
||||
subroot_relative_path: Option<Arc<Path>>,
|
||||
_: Option<HashMap<String, String>>,
|
||||
) -> ToolchainList {
|
||||
// This lister will always return a path .venv directories within ancestors
|
||||
let ancestors = subroot_relative_path
|
||||
.into_iter()
|
||||
.flat_map(|path| path.ancestors().map(ToOwned::to_owned).collect::<Vec<_>>());
|
||||
let mut toolchains = vec![];
|
||||
for ancestor in ancestors {
|
||||
let venv_path = worktree_root.join(ancestor).join(".venv");
|
||||
if self.0.is_dir(&venv_path).await {
|
||||
toolchains.push(Toolchain {
|
||||
name: SharedString::new("Python Venv"),
|
||||
path: venv_path.to_string_lossy().into_owned().into(),
|
||||
language_name: LanguageName(SharedString::new_static("Python")),
|
||||
as_json: serde_json::Value::Null,
|
||||
})
|
||||
}
|
||||
}
|
||||
ToolchainList {
|
||||
toolchains,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
// Returns a term which we should use in UI to refer to a toolchain.
|
||||
fn term(&self) -> SharedString {
|
||||
SharedString::new_static("virtual environment")
|
||||
}
|
||||
/// Returns the name of the manifest file for this toolchain.
|
||||
fn manifest_name(&self) -> ManifestName {
|
||||
SharedString::new_static("pyproject.toml").into()
|
||||
}
|
||||
}
|
||||
Arc::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Python".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["py".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
None, // We're not testing Python parsing with this language.
|
||||
)
|
||||
.with_manifest(Some(ManifestName::from(SharedString::new_static(
|
||||
"pyproject.toml",
|
||||
))))
|
||||
.with_toolchain_lister(Some(Arc::new(PythonMootToolchainLister(fs)))),
|
||||
)
|
||||
}
|
||||
|
||||
fn typescript_lang() -> Arc<Language> {
|
||||
Arc::new(Language::new(
|
||||
LanguageConfig {
|
||||
|
||||
@@ -205,13 +205,13 @@ where
|
||||
|
||||
#[track_caller]
|
||||
pub fn prev(&mut self) {
|
||||
self.search_backward(|_| true)
|
||||
self.search_backward(|_| Ordering::Equal)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn search_backward<F>(&mut self, mut filter_node: F)
|
||||
where
|
||||
F: FnMut(&T::Summary) -> bool,
|
||||
F: FnMut(&T::Summary) -> Ordering,
|
||||
{
|
||||
if !self.did_seek {
|
||||
self.did_seek = true;
|
||||
@@ -222,10 +222,15 @@ where
|
||||
self.position = D::zero(self.cx);
|
||||
self.at_end = self.tree.is_empty();
|
||||
if !self.tree.is_empty() {
|
||||
let position = if let Some(summary) = self.tree.0.summary() {
|
||||
D::from_summary(summary, self.cx)
|
||||
} else {
|
||||
D::zero(self.cx)
|
||||
};
|
||||
self.stack.push(StackEntry {
|
||||
tree: self.tree,
|
||||
index: self.tree.0.child_summaries().len(),
|
||||
position: D::from_summary(self.tree.summary(), self.cx),
|
||||
position,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -248,12 +253,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
for summary in &entry.tree.0.child_summaries()[..entry.index] {
|
||||
self.position.add_summary(summary, self.cx);
|
||||
if entry.index != 0 {
|
||||
self.position
|
||||
.add_summary(&entry.tree.0.child_summaries()[entry.index - 1], self.cx);
|
||||
}
|
||||
|
||||
entry.position = self.position.clone();
|
||||
|
||||
descending = filter_node(&entry.tree.0.child_summaries()[entry.index]);
|
||||
descending = filter_node(&entry.tree.0.child_summaries()[entry.index]).is_ge();
|
||||
match entry.tree.0.as_ref() {
|
||||
Node::Internal { child_trees, .. } => {
|
||||
if descending {
|
||||
@@ -276,13 +283,13 @@ where
|
||||
|
||||
#[track_caller]
|
||||
pub fn next(&mut self) {
|
||||
self.search_forward(|_| true)
|
||||
self.search_forward(|_| Ordering::Equal)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn search_forward<F>(&mut self, mut filter_node: F)
|
||||
where
|
||||
F: FnMut(&T::Summary) -> bool,
|
||||
F: FnMut(&T::Summary) -> Ordering,
|
||||
{
|
||||
let mut descend = false;
|
||||
|
||||
@@ -312,14 +319,15 @@ where
|
||||
entry.position = self.position.clone();
|
||||
}
|
||||
|
||||
while entry.index < child_summaries.len() {
|
||||
let next_summary = &child_summaries[entry.index];
|
||||
if filter_node(next_summary) {
|
||||
break;
|
||||
} else {
|
||||
entry.index += 1;
|
||||
entry.position.add_summary(next_summary, self.cx);
|
||||
self.position.add_summary(next_summary, self.cx);
|
||||
if entry.index < child_summaries.len() {
|
||||
let index = child_summaries[entry.index..]
|
||||
.partition_point(|item| filter_node(item).is_lt());
|
||||
|
||||
entry.index += index;
|
||||
|
||||
if let Some(summary) = child_summaries.get(entry.index) {
|
||||
entry.position.add_summary(summary, self.cx);
|
||||
self.position.add_summary(summary, self.cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,18 +341,19 @@ where
|
||||
self.position.add_summary(item_summary, self.cx);
|
||||
}
|
||||
|
||||
loop {
|
||||
if let Some(next_item_summary) = item_summaries.get(entry.index) {
|
||||
if filter_node(next_item_summary) {
|
||||
return;
|
||||
} else {
|
||||
entry.index += 1;
|
||||
entry.position.add_summary(next_item_summary, self.cx);
|
||||
self.position.add_summary(next_item_summary, self.cx);
|
||||
}
|
||||
} else {
|
||||
break None;
|
||||
if entry.index < item_summaries.len() {
|
||||
let index = item_summaries[entry.index..]
|
||||
.partition_point(|item| filter_node(item).is_lt());
|
||||
|
||||
entry.index += index;
|
||||
|
||||
if let Some(summary) = item_summaries.get(entry.index) {
|
||||
entry.position.add_summary(summary, self.cx);
|
||||
self.position.add_summary(summary, self.cx);
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -405,7 +414,7 @@ where
|
||||
Target: SeekTarget<'a, T::Summary, D>,
|
||||
{
|
||||
let mut slice = SliceSeekAggregate {
|
||||
tree: SumTree::new(self.cx),
|
||||
tree: SumTree::new(),
|
||||
leaf_items: ArrayVec::new(),
|
||||
leaf_item_summaries: ArrayVec::new(),
|
||||
leaf_summary: <T::Summary as Summary>::zero(self.cx),
|
||||
@@ -632,7 +641,7 @@ pub struct FilterCursor<'a, F, T: Item, D> {
|
||||
|
||||
impl<'a, F, T: Item, D> FilterCursor<'a, F, T, D>
|
||||
where
|
||||
F: FnMut(&T::Summary) -> bool,
|
||||
F: FnMut(&T::Summary) -> Ordering,
|
||||
T: Item,
|
||||
D: Dimension<'a, T::Summary>,
|
||||
{
|
||||
@@ -675,7 +684,7 @@ where
|
||||
|
||||
impl<'a, F, T: Item, U> Iterator for FilterCursor<'a, F, T, U>
|
||||
where
|
||||
F: FnMut(&T::Summary) -> bool,
|
||||
F: FnMut(&T::Summary) -> Ordering,
|
||||
U: Dimension<'a, T::Summary>,
|
||||
{
|
||||
type Item = &'a T;
|
||||
@@ -732,7 +741,6 @@ impl<T: Item> SeekAggregate<'_, T> for SliceSeekAggregate<T> {
|
||||
fn end_leaf(&mut self, cx: &<T::Summary as Summary>::Context) {
|
||||
self.tree.append(
|
||||
SumTree(Arc::new(Node::Leaf {
|
||||
summary: mem::replace(&mut self.leaf_summary, <T::Summary as Summary>::zero(cx)),
|
||||
items: mem::take(&mut self.leaf_items),
|
||||
item_summaries: mem::take(&mut self.leaf_item_summaries),
|
||||
})),
|
||||
|
||||
@@ -4,6 +4,8 @@ mod tree_map;
|
||||
use arrayvec::ArrayVec;
|
||||
pub use cursor::{Cursor, FilterCursor, Iter};
|
||||
use rayon::prelude::*;
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
|
||||
@@ -39,6 +41,7 @@ pub trait Summary: Clone {
|
||||
|
||||
fn zero(cx: &Self::Context) -> Self;
|
||||
fn add_summary(&mut self, summary: &Self, cx: &Self::Context);
|
||||
fn sub_summary(&mut self, _: &Self, _: &Self::Context) {}
|
||||
}
|
||||
|
||||
/// Catch-all implementation for when you need something that implements [`Summary`] without a specific type.
|
||||
@@ -187,25 +190,15 @@ where
|
||||
}
|
||||
|
||||
impl<T: Item> SumTree<T> {
|
||||
pub fn new(cx: &<T::Summary as Summary>::Context) -> Self {
|
||||
pub fn new() -> Self {
|
||||
SumTree(Arc::new(Node::Leaf {
|
||||
summary: <T::Summary as Summary>::zero(cx),
|
||||
items: ArrayVec::new(),
|
||||
item_summaries: ArrayVec::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Useful in cases where the item type has a non-trivial context type, but the zero value of the summary type doesn't depend on that context.
|
||||
pub fn from_summary(summary: T::Summary) -> Self {
|
||||
SumTree(Arc::new(Node::Leaf {
|
||||
summary,
|
||||
items: ArrayVec::new(),
|
||||
item_summaries: ArrayVec::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn from_item(item: T, cx: &<T::Summary as Summary>::Context) -> Self {
|
||||
let mut tree = Self::new(cx);
|
||||
let mut tree = Self::new();
|
||||
tree.push(item, cx);
|
||||
tree
|
||||
}
|
||||
@@ -219,16 +212,21 @@ impl<T: Item> SumTree<T> {
|
||||
let mut iter = iter.into_iter().fuse().peekable();
|
||||
while iter.peek().is_some() {
|
||||
let items: ArrayVec<T, { 2 * TREE_BASE }> = iter.by_ref().take(2 * TREE_BASE).collect();
|
||||
let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
|
||||
items.iter().map(|item| item.summary(cx)).collect();
|
||||
|
||||
let mut summary = item_summaries[0].clone();
|
||||
for item_summary in &item_summaries[1..] {
|
||||
<T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
|
||||
}
|
||||
|
||||
let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = items
|
||||
.iter()
|
||||
.scan(None, |previous_item, item| {
|
||||
let summary = item.summary(cx);
|
||||
let current_item = if let Some(mut base) = previous_item.take() {
|
||||
<T::Summary as Summary>::add_summary(&mut base, &summary, cx);
|
||||
base
|
||||
} else {
|
||||
summary
|
||||
};
|
||||
_ = previous_item.insert(current_item.clone());
|
||||
Some(current_item)
|
||||
})
|
||||
.collect();
|
||||
nodes.push(Node::Leaf {
|
||||
summary,
|
||||
items,
|
||||
item_summaries,
|
||||
});
|
||||
@@ -241,13 +239,11 @@ impl<T: Item> SumTree<T> {
|
||||
let mut current_parent_node = None;
|
||||
for child_node in nodes.drain(..) {
|
||||
let parent_node = current_parent_node.get_or_insert_with(|| Node::Internal {
|
||||
summary: <T::Summary as Summary>::zero(cx),
|
||||
height,
|
||||
child_summaries: ArrayVec::new(),
|
||||
child_trees: ArrayVec::new(),
|
||||
});
|
||||
let Node::Internal {
|
||||
summary,
|
||||
child_summaries,
|
||||
child_trees,
|
||||
..
|
||||
@@ -255,9 +251,15 @@ impl<T: Item> SumTree<T> {
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
let child_summary = child_node.summary();
|
||||
<T::Summary as Summary>::add_summary(summary, child_summary, cx);
|
||||
child_summaries.push(child_summary.clone());
|
||||
let child_summary = child_node.summary_or_zero(cx);
|
||||
let child_summary = if let Some(mut last) = child_summaries.last().cloned() {
|
||||
<T::Summary as Summary>::add_summary(&mut last, &child_summary, cx);
|
||||
last
|
||||
} else {
|
||||
child_summary.into_owned()
|
||||
};
|
||||
|
||||
child_summaries.push(child_summary);
|
||||
child_trees.push(Self(Arc::new(child_node)));
|
||||
|
||||
if child_trees.len() == 2 * TREE_BASE {
|
||||
@@ -269,7 +271,7 @@ impl<T: Item> SumTree<T> {
|
||||
}
|
||||
|
||||
if nodes.is_empty() {
|
||||
Self::new(cx)
|
||||
Self::new()
|
||||
} else {
|
||||
debug_assert_eq!(nodes.len(), 1);
|
||||
Self(Arc::new(nodes.pop().unwrap()))
|
||||
@@ -289,14 +291,22 @@ impl<T: Item> SumTree<T> {
|
||||
.chunks(2 * TREE_BASE)
|
||||
.map(|items| {
|
||||
let items: ArrayVec<T, { 2 * TREE_BASE }> = items.into_iter().collect();
|
||||
let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> =
|
||||
items.iter().map(|item| item.summary(cx)).collect();
|
||||
let mut summary = item_summaries[0].clone();
|
||||
for item_summary in &item_summaries[1..] {
|
||||
<T::Summary as Summary>::add_summary(&mut summary, item_summary, cx);
|
||||
}
|
||||
let item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = items
|
||||
.iter()
|
||||
.scan(None, |previous_item, item| {
|
||||
let summary = item.summary(cx);
|
||||
let current_item = if let Some(mut base) = previous_item.take() {
|
||||
<T::Summary as Summary>::add_summary(&mut base, &summary, cx);
|
||||
base
|
||||
} else {
|
||||
summary
|
||||
};
|
||||
_ = previous_item.insert(current_item.clone());
|
||||
Some(current_item)
|
||||
})
|
||||
.collect();
|
||||
|
||||
SumTree(Arc::new(Node::Leaf {
|
||||
summary,
|
||||
items,
|
||||
item_summaries,
|
||||
}))
|
||||
@@ -314,7 +324,7 @@ impl<T: Item> SumTree<T> {
|
||||
child_nodes.into_iter().collect();
|
||||
let child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }> = child_trees
|
||||
.iter()
|
||||
.map(|child_tree| child_tree.summary().clone())
|
||||
.map(|child_tree| child_tree.0.summary_or_zero(cx).into_owned())
|
||||
.collect();
|
||||
let mut summary = child_summaries[0].clone();
|
||||
for child_summary in &child_summaries[1..] {
|
||||
@@ -322,7 +332,7 @@ impl<T: Item> SumTree<T> {
|
||||
}
|
||||
SumTree(Arc::new(Node::Internal {
|
||||
height,
|
||||
summary,
|
||||
|
||||
child_summaries,
|
||||
child_trees,
|
||||
}))
|
||||
@@ -331,7 +341,7 @@ impl<T: Item> SumTree<T> {
|
||||
}
|
||||
|
||||
if nodes.is_empty() {
|
||||
Self::new(cx)
|
||||
Self::new()
|
||||
} else {
|
||||
debug_assert_eq!(nodes.len(), 1);
|
||||
nodes.pop().unwrap()
|
||||
@@ -369,7 +379,7 @@ impl<T: Item> SumTree<T> {
|
||||
filter_node: F,
|
||||
) -> FilterCursor<'a, F, T, U>
|
||||
where
|
||||
F: FnMut(&T::Summary) -> bool,
|
||||
F: FnMut(&T::Summary) -> Ordering,
|
||||
U: Dimension<'a, T::Summary>,
|
||||
{
|
||||
FilterCursor::new(self, cx, filter_node)
|
||||
@@ -395,28 +405,57 @@ impl<T: Item> SumTree<T> {
|
||||
) -> Option<T::Summary> {
|
||||
match Arc::make_mut(&mut self.0) {
|
||||
Node::Internal {
|
||||
summary,
|
||||
child_summaries,
|
||||
child_trees,
|
||||
..
|
||||
} => {
|
||||
let last_summary = child_summaries.last_mut().unwrap();
|
||||
let last_child = child_trees.last_mut().unwrap();
|
||||
*last_summary = last_child.update_last_recursive(f, cx).unwrap();
|
||||
*summary = sum(child_summaries.iter(), cx);
|
||||
Some(summary.clone())
|
||||
|
||||
let mut bare_summary = last_child.update_last_recursive(f, cx).unwrap();
|
||||
|
||||
if let Some(mut second_to_last_summary) = child_summaries
|
||||
.len()
|
||||
.checked_sub(2)
|
||||
.and_then(|ix| child_summaries.get(ix))
|
||||
.cloned()
|
||||
{
|
||||
<T::Summary as Summary>::add_summary(
|
||||
&mut second_to_last_summary,
|
||||
&bare_summary,
|
||||
cx,
|
||||
);
|
||||
bare_summary = second_to_last_summary;
|
||||
}
|
||||
let last_summary = child_summaries.last_mut().unwrap();
|
||||
*last_summary = bare_summary;
|
||||
|
||||
Some(last_summary.clone())
|
||||
}
|
||||
Node::Leaf {
|
||||
summary,
|
||||
items,
|
||||
item_summaries,
|
||||
} => {
|
||||
let preceding_summary = item_summaries
|
||||
.len()
|
||||
.checked_sub(2)
|
||||
.and_then(|ix| item_summaries.get(ix))
|
||||
.cloned();
|
||||
if let Some((item, item_summary)) = items.last_mut().zip(item_summaries.last_mut())
|
||||
{
|
||||
(f)(item);
|
||||
*item_summary = item.summary(cx);
|
||||
*summary = sum(item_summaries.iter(), cx);
|
||||
Some(summary.clone())
|
||||
let mut bare_summary = item.summary(cx);
|
||||
|
||||
if let Some(mut second_to_last_summary) = preceding_summary {
|
||||
<T::Summary as Summary>::add_summary(
|
||||
&mut second_to_last_summary,
|
||||
&bare_summary,
|
||||
cx,
|
||||
);
|
||||
bare_summary = second_to_last_summary;
|
||||
}
|
||||
*item_summary = bare_summary.clone();
|
||||
|
||||
Some(item_summary.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -429,19 +468,11 @@ impl<T: Item> SumTree<T> {
|
||||
cx: &<T::Summary as Summary>::Context,
|
||||
) -> D {
|
||||
let mut extent = D::zero(cx);
|
||||
match self.0.as_ref() {
|
||||
Node::Internal { summary, .. } | Node::Leaf { summary, .. } => {
|
||||
extent.add_summary(summary, cx);
|
||||
}
|
||||
if let Some(last) = self.0.child_summaries().last() {
|
||||
extent.add_summary(last, cx);
|
||||
}
|
||||
extent
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> &T::Summary {
|
||||
match self.0.as_ref() {
|
||||
Node::Internal { summary, .. } => summary,
|
||||
Node::Leaf { summary, .. } => summary,
|
||||
}
|
||||
extent
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
@@ -473,7 +504,6 @@ impl<T: Item> SumTree<T> {
|
||||
let summary = item.summary(cx);
|
||||
self.append(
|
||||
SumTree(Arc::new(Node::Leaf {
|
||||
summary: summary.clone(),
|
||||
items: ArrayVec::from_iter(Some(item)),
|
||||
item_summaries: ArrayVec::from_iter(Some(summary)),
|
||||
})),
|
||||
@@ -503,13 +533,11 @@ impl<T: Item> SumTree<T> {
|
||||
match Arc::make_mut(&mut self.0) {
|
||||
Node::Internal {
|
||||
height,
|
||||
summary,
|
||||
child_summaries,
|
||||
child_trees,
|
||||
..
|
||||
} => {
|
||||
let other_node = other.0.clone();
|
||||
<T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
|
||||
|
||||
let height_delta = *height - other_node.height();
|
||||
let mut summaries_to_append = ArrayVec::<T::Summary, { 2 * TREE_BASE }>::new();
|
||||
@@ -518,18 +546,22 @@ impl<T: Item> SumTree<T> {
|
||||
summaries_to_append.extend(other_node.child_summaries().iter().cloned());
|
||||
trees_to_append.extend(other_node.child_trees().iter().cloned());
|
||||
} else if height_delta == 1 && !other_node.is_underflowing() {
|
||||
summaries_to_append.push(other_node.summary().clone());
|
||||
summaries_to_append.push(other_node.summary_or_zero(cx).into_owned());
|
||||
trees_to_append.push(other)
|
||||
} else {
|
||||
let tree_to_append = child_trees
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.push_tree_recursive(other, cx);
|
||||
*child_summaries.last_mut().unwrap() =
|
||||
child_trees.last().unwrap().0.summary().clone();
|
||||
*child_summaries.last_mut().unwrap() = child_trees
|
||||
.last()
|
||||
.unwrap()
|
||||
.0
|
||||
.summary_or_zero(cx)
|
||||
.into_owned();
|
||||
|
||||
if let Some(split_tree) = tree_to_append {
|
||||
summaries_to_append.push(split_tree.0.summary().clone());
|
||||
summaries_to_append.push(split_tree.0.summary_or_zero(cx).into_owned());
|
||||
trees_to_append.push(split_tree);
|
||||
}
|
||||
}
|
||||
@@ -554,13 +586,12 @@ impl<T: Item> SumTree<T> {
|
||||
left_trees = all_trees.by_ref().take(midpoint).collect();
|
||||
right_trees = all_trees.collect();
|
||||
}
|
||||
*summary = sum(left_summaries.iter(), cx);
|
||||
|
||||
*child_summaries = left_summaries;
|
||||
*child_trees = left_trees;
|
||||
|
||||
Some(SumTree(Arc::new(Node::Internal {
|
||||
height: *height,
|
||||
summary: sum(right_summaries.iter(), cx),
|
||||
child_summaries: right_summaries,
|
||||
child_trees: right_trees,
|
||||
})))
|
||||
@@ -571,7 +602,6 @@ impl<T: Item> SumTree<T> {
|
||||
}
|
||||
}
|
||||
Node::Leaf {
|
||||
summary,
|
||||
items,
|
||||
item_summaries,
|
||||
} => {
|
||||
@@ -599,16 +629,24 @@ impl<T: Item> SumTree<T> {
|
||||
}
|
||||
*items = left_items;
|
||||
*item_summaries = left_summaries;
|
||||
*summary = sum(item_summaries.iter(), cx);
|
||||
|
||||
Some(SumTree(Arc::new(Node::Leaf {
|
||||
items: right_items,
|
||||
summary: sum(right_summaries.iter(), cx),
|
||||
|
||||
item_summaries: right_summaries,
|
||||
})))
|
||||
} else {
|
||||
<T::Summary as Summary>::add_summary(summary, other_node.summary(), cx);
|
||||
let baseline = item_summaries.last().cloned();
|
||||
|
||||
items.extend(other_node.items().iter().cloned());
|
||||
item_summaries.extend(other_node.child_summaries().iter().cloned());
|
||||
item_summaries.extend(other_node.child_summaries().iter().map(|summary| {
|
||||
if let Some(mut baseline) = baseline.clone() {
|
||||
<T::Summary as Summary>::add_summary(&mut baseline, summary, cx);
|
||||
baseline
|
||||
} else {
|
||||
summary.clone()
|
||||
}
|
||||
}));
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -622,14 +660,14 @@ impl<T: Item> SumTree<T> {
|
||||
) -> Self {
|
||||
let height = left.0.height() + 1;
|
||||
let mut child_summaries = ArrayVec::new();
|
||||
child_summaries.push(left.0.summary().clone());
|
||||
child_summaries.push(right.0.summary().clone());
|
||||
child_summaries.push(left.0.summary_or_zero(cx).into_owned());
|
||||
child_summaries.push(right.0.summary_or_zero(cx).into_owned());
|
||||
let mut child_trees = ArrayVec::new();
|
||||
child_trees.push(left);
|
||||
child_trees.push(right);
|
||||
SumTree(Arc::new(Node::Internal {
|
||||
height,
|
||||
summary: sum(child_summaries.iter(), cx),
|
||||
|
||||
child_summaries,
|
||||
child_trees,
|
||||
}))
|
||||
@@ -716,7 +754,7 @@ impl<T: KeyedItem> SumTree<T> {
|
||||
|
||||
*self = {
|
||||
let mut cursor = self.cursor::<T::Key>(cx);
|
||||
let mut new_tree = SumTree::new(cx);
|
||||
let mut new_tree = SumTree::new();
|
||||
let mut buffered_items = Vec::new();
|
||||
|
||||
cursor.seek(&T::Key::zero(cx), Bias::Left);
|
||||
@@ -771,13 +809,9 @@ impl<T: KeyedItem> SumTree<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> Default for SumTree<T>
|
||||
where
|
||||
T: Item<Summary = S>,
|
||||
S: Summary<Context = ()>,
|
||||
{
|
||||
impl<T: Item> Default for SumTree<T> {
|
||||
fn default() -> Self {
|
||||
Self::new(&())
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -785,12 +819,10 @@ where
|
||||
pub enum Node<T: Item> {
|
||||
Internal {
|
||||
height: u8,
|
||||
summary: T::Summary,
|
||||
child_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
|
||||
child_trees: ArrayVec<SumTree<T>, { 2 * TREE_BASE }>,
|
||||
},
|
||||
Leaf {
|
||||
summary: T::Summary,
|
||||
items: ArrayVec<T, { 2 * TREE_BASE }>,
|
||||
item_summaries: ArrayVec<T::Summary, { 2 * TREE_BASE }>,
|
||||
},
|
||||
@@ -805,23 +837,19 @@ where
|
||||
match self {
|
||||
Node::Internal {
|
||||
height,
|
||||
summary,
|
||||
child_summaries,
|
||||
child_trees,
|
||||
} => f
|
||||
.debug_struct("Internal")
|
||||
.field("height", height)
|
||||
.field("summary", summary)
|
||||
.field("child_summaries", child_summaries)
|
||||
.field("child_trees", child_trees)
|
||||
.finish(),
|
||||
Node::Leaf {
|
||||
summary,
|
||||
items,
|
||||
item_summaries,
|
||||
} => f
|
||||
.debug_struct("Leaf")
|
||||
.field("summary", summary)
|
||||
.field("items", items)
|
||||
.field("item_summaries", item_summaries)
|
||||
.finish(),
|
||||
@@ -841,11 +869,16 @@ impl<T: Item> Node<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn summary(&self) -> &T::Summary {
|
||||
match self {
|
||||
Node::Internal { summary, .. } => summary,
|
||||
Node::Leaf { summary, .. } => summary,
|
||||
}
|
||||
fn summary<'a>(&'a self) -> Option<&'a T::Summary> {
|
||||
let child_summaries = self.child_summaries();
|
||||
child_summaries.last()
|
||||
}
|
||||
|
||||
fn summary_or_zero<'a>(&'a self, cx: &<T::Summary as Summary>::Context) -> Cow<'a, T::Summary> {
|
||||
self.summary().map_or_else(
|
||||
|| Cow::Owned(<T::Summary as Summary>::zero(cx)),
|
||||
|last_summary| Cow::Borrowed(last_summary),
|
||||
)
|
||||
}
|
||||
|
||||
fn child_summaries(&self) -> &[T::Summary] {
|
||||
@@ -894,18 +927,6 @@ impl<T: KeyedItem> Edit<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn sum<'a, T, I>(iter: I, cx: &T::Context) -> T
|
||||
where
|
||||
T: 'a + Summary,
|
||||
I: Iterator<Item = &'a T>,
|
||||
{
|
||||
let mut sum = T::zero(cx);
|
||||
for value in iter {
|
||||
sum.add_summary(value, cx);
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -997,7 +1018,7 @@ mod tests {
|
||||
log::info!("tree items: {:?}", tree.items(&()));
|
||||
|
||||
let mut filter_cursor =
|
||||
tree.filter::<_, Count>(&(), |summary| summary.contains_even);
|
||||
tree.filter::<_, Count>(&(), |summary| summary.contains_even.cmp(&false));
|
||||
let expected_filtered_items = tree
|
||||
.items(&())
|
||||
.into_iter()
|
||||
@@ -1096,7 +1117,7 @@ mod tests {
|
||||
cursor.seek(&Count(start), start_bias);
|
||||
let summary = cursor.summary::<_, Sum>(&Count(end), end_bias);
|
||||
|
||||
assert_eq!(summary.0, slice.summary().sum);
|
||||
assert_eq!(summary.0, slice.0.summary_or_zero(&()).sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1172,9 +1193,11 @@ mod tests {
|
||||
// Multiple-element tree
|
||||
let mut tree = SumTree::default();
|
||||
tree.extend(vec![1, 2, 3, 4, 5, 6], &());
|
||||
|
||||
let mut cursor = tree.cursor::<IntegersSummary>(&());
|
||||
|
||||
assert_eq!(cursor.slice(&Count(2), Bias::Right).items(&()), [1, 2]);
|
||||
let slice = cursor.slice(&Count(2), Bias::Right);
|
||||
assert_eq!(slice.items(&()), [1, 2]);
|
||||
assert_eq!(cursor.item(), Some(&3));
|
||||
assert_eq!(cursor.prev_item(), Some(&2));
|
||||
assert_eq!(cursor.next_item(), Some(&4));
|
||||
|
||||
@@ -407,12 +407,12 @@ mod tests {
|
||||
map.insert("baaab", 4);
|
||||
map.insert("c", 5);
|
||||
|
||||
let result = map
|
||||
.iter_from(&"ba")
|
||||
let items = map.iter_from(&"ba");
|
||||
let result = items
|
||||
.take_while(|(key, _)| key.starts_with("ba"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result.len(), 2, "{result:?}");
|
||||
assert!(result.iter().any(|(k, _)| k == &&"baa"));
|
||||
assert!(result.iter().any(|(k, _)| k == &&"baaab"));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user