Compare commits

...

21 Commits

Author SHA1 Message Date
Nathan Sobo
f2409f2605 Run cargo fmt 2025-12-14 12:21:46 -07:00
Nathan Sobo
ce1c228e6e Rename TestAppWindow to TestWindow, internal TestWindow to TestPlatformWindow
- Public API: TestWindow<V> - the new typed test window wrapper
- Internal: TestPlatformWindow - the platform-level mock window (pub(crate))
2025-12-14 10:48:11 -07:00
Nathan Sobo
96ddbd4e13 Add TestApp and TestAppWindow for cleaner GPUI testing
TestApp provides a simpler alternative to TestAppContext with:
- Automatic effect flushing after updates
- Clean window creation returning typed TestAppWindow<V>
- Scene inspection via SceneSnapshot
- Input simulation helpers

Also adds:
- Background::as_solid() helper in color.rs
- SceneSnapshot for inspecting rendered quads/glyphs in scene.rs
2025-12-14 10:43:35 -07:00
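The commit above describes the intended test flow: updates that automatically flush pending effects before returning control to the test. The sketch below illustrates that auto-flush pattern with hypothetical stand-in types; it is not the real gpui `TestApp` API, whose signatures differ.

```rust
// Minimal sketch of the "flush effects after every update" idea behind
// TestApp (hypothetical types, not the actual gpui implementation).
struct TestApp {
    counter: i32,
    effects: Vec<fn(&mut TestApp)>,
}

impl TestApp {
    fn new() -> Self {
        Self { counter: 0, effects: Vec::new() }
    }

    /// Queue an effect to run at the end of the current update.
    fn defer(&mut self, effect: fn(&mut TestApp)) {
        self.effects.push(effect);
    }

    /// Run `f`, then drain every queued effect before returning, so a test
    /// never observes a half-applied state.
    fn update<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
        let result = f(self);
        while let Some(effect) = self.effects.pop() {
            effect(self);
        }
        result
    }
}

fn main() {
    let mut app = TestApp::new();
    app.update(|app| {
        app.defer(|app| app.counter += 1);
    });
    // The deferred effect ran before `update` returned.
    assert_eq!(app.counter, 1);
    println!("ok");
}
```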
Nathan Sobo
f224d2a923 Add TestApp and TestAppWindow for cleaner GPUI testing
Adds zed/crates/gpui/src/app/test_app.rs with:

- TestApp: test context that auto-runs until parked after updates
- TestAppWindow<V>: window wrapper with input simulation helpers

Minor improvement over TestAppContext/VisualTestContext - mainly
convenience (auto-parking, owned window handle, cleaner signatures).

Does NOT solve the deeper issues:
- Scene is still pub(crate), can't inspect rendered output
- Editor still needs FocusHandle which needs real GPUI context
- TestEditor duplication in ex still exists

3 tests included demonstrating basic usage.
2025-12-14 10:34:21 -07:00
John Tur
6cc947f654 Update cc and cmake crates (#44797)
This fixes the build when Visual Studio 2026 is installed.

Release Notes:

- N/A
2025-12-14 07:45:54 +00:00
Will Garrison
f2cc24c5fa docs: Add clarifying note about Vim subword motion (#44535)
Clarify the docs regarding how operators are affected when subword
motion in Vim is activated.

Ref:
https://github.com/zed-industries/zed/issues/23344#issuecomment-3186025873.

Release Notes:

- N/A

---------

Co-authored-by: Kunall Banerjee <hey@kimchiii.space>
2025-12-14 02:20:33 -05:00
Michael Benfield
488fa02547 Streaming tool use for inline assistant (#44751)
Depends on: https://github.com/zed-industries/zed/pull/44753

Release Notes:

- N/A

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-12-14 03:22:20 +00:00
Cole Miller
dad6481e02 Disambiguate branch name in title bar (#44793)
Add the repository name when:

- there's more than one repository, and
- the name of the active repository doesn't match the name of the
project (to avoid stuttering with the adjacent project switcher button)

Release Notes:

- The branch name in the title bar now includes the name of the current
repository when needed to disambiguate.
2025-12-14 02:51:58 +00:00
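The two conditions in the commit above can be sketched as a small pure function. The names here are hypothetical illustrations, not the actual Zed title-bar code.

```rust
// Sketch of the disambiguation rule: show "repo/branch" only when there are
// multiple repositories AND the active repo's name differs from the project
// name (to avoid stuttering next to the project switcher button).
fn branch_label(branch: &str, repo: &str, project: &str, repo_count: usize) -> String {
    if repo_count > 1 && repo != project {
        format!("{repo}/{branch}")
    } else {
        branch.to_string()
    }
}

fn main() {
    // Single repository: plain branch name.
    assert_eq!(branch_label("main", "zed", "zed", 1), "main");
    // Multiple repos, but the active repo matches the project name.
    assert_eq!(branch_label("main", "zed", "zed", 2), "main");
    // Multiple repos with a differing name: disambiguate.
    assert_eq!(branch_label("main", "docs", "zed", 2), "docs/main");
    println!("ok");
}
```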
Danilo Leal
0283bfb049 Enable configuring edit prediction providers through the settings UI (#44505)
- Edit prediction providers can now be configured through the settings
UI
- Cleaned up the status bar menu to only show _configured_ providers
- Added the name of the active provider to the status bar icon button
tooltip
- Only display the data collection functionality under "Privacy" for the
Zed models
- Moved the Codestral edit prediction provider out of the Mistral
section in the agent panel into the settings UI
- Refined and improved UI and states for configuring GitHub Copilot as
both an agent and edit prediction provider

#### Todos before merge:

- [x] UI: Unify with settings UI style and tidy it all up
- [x] Unify Copilot modal `impl`s to use separate window
- [x] Remove stop light icons from GitHub modal
- [x] Make dismiss events work on GitHub modal
- [ ] Investigate workarounds to tell if Copilot authenticated even when
LSP not running


Release Notes:

- settings_ui: Added a section for configuring edit prediction providers
under AI > Edit Predictions, including Codestral and GitHub Copilot.
Once you've updated, you can use the following link to open it:
zed://settings/edit_predictions.providers

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-13 11:06:30 -05:00
Michael Benfield
56daba28d4 supports_streaming_tools member (#44753)
Release Notes:

- N/A
2025-12-13 00:56:06 +00:00
Josh Ayres
6e0ecbcb07 docs: Use relative_line_numbers instead of toggle_relative_line_numbers (#44749)
Just a small docs change

With the deprecation of `toggle_relative_line_numbers` the docs should
reflect that

Release Notes:

- N/A
2025-12-13 00:41:31 +00:00
Haojian Wu
4754422ef4 Add angled bracket highlighting for C++ (#44735)
Enables rainbow bracket highlighting for angle brackets (< >) in C++.

<img width="401" height="46" alt="image"
src="https://github.com/user-attachments/assets/169afdaa-c8be-4b78-bf64-9cf08787eb47"
/>


Release Notes:

- Added rainbow bracket coloring for C++ angle brackets (`<>`)
2025-12-13 01:38:44 +01:00
Marco Mihai Condrache
e860252185 gpui: Improve path rendering and bounds performance (#44655) 2025-12-12 23:01:16 +00:00
Anthony Eid
fad06dd00c git: Show all branches in branch picker empty state (#44742)
This fixes an issue where a user could get confused by the branch picker
because it would only show the 10 most recent branches, instead of all
branches.

Release Notes:

- git: Show all branches in branch picker when search field is empty
2025-12-12 17:59:35 -05:00
Xiaobo Liu
329ec645da gpui: Fix tab jitter from oversized scrolling (#42434) 2025-12-12 22:27:09 +00:00
Oleksiy Syvokon
e1d236eaf0 ep: Apply diff to editable region only and edit history fixes (#44737)
Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-12-12 21:18:13 +00:00
Agus Zubiaga
60f4aa333b edit prediction cli: Improve error handling (#44718)
We were panicking whenever something went wrong with an example in the
CLI. This can be very disruptive when running many examples and, for
example, a single request fails. Instead, when running more than one
example, errors are now logged alongside instructions to explore and
re-run the failing example by itself.

<img width="1454" height="744" alt="CleanShot 2025-12-12 at 13 32 04@2x"
src="https://github.com/user-attachments/assets/87c59e64-08b9-4461-af5b-03af5de94152"></img>


You can still opt in to stop as soon as an error occurs with the new
`--failfast` argument.

Release Notes:

- N/A
2025-12-12 14:15:58 -03:00
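The collect-errors-instead-of-panicking policy described above, with an opt-in failfast mode, can be sketched as follows. All names here are hypothetical; the real CLI's types differ.

```rust
// Sketch: run a batch of examples, collecting failures rather than
// panicking. With `failfast`, stop at the first error instead.
fn run_examples(
    examples: &[&str],
    failfast: bool,
    run: impl Fn(&str) -> Result<(), String>,
) -> Vec<(String, String)> {
    let mut failures = Vec::new();
    for example in examples {
        if let Err(err) = run(example) {
            failures.push((example.to_string(), err));
            if failfast {
                break;
            }
        }
    }
    failures
}

fn main() {
    let run = |name: &str| {
        if name == "bad" { Err("request failed".to_string()) } else { Ok(()) }
    };
    // Without failfast, the remaining examples still run.
    let failures = run_examples(&["ok1", "bad", "ok2"], false, &run);
    assert_eq!(failures.len(), 1);
    // With failfast, execution stops at the first error.
    let failures = run_examples(&["bad", "ok1"], true, &run);
    assert_eq!(failures.len(), 1);
    println!("ok");
}
```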
localcc
a698f1bf63 Fix Bounds::contains (#44711)
Closes #11643 

Release Notes:

- Fixed double hover state on Windows

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-12-12 14:49:29 +00:00
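A common cause of the double-hover symptom fixed above is closed-interval containment, where a point on the shared edge of two adjacent bounds is "contained" by both. The sketch below shows the half-open convention that avoids this; the actual gpui `Bounds` type differs, so treat this as an illustration of the idea rather than the fix itself.

```rust
#[derive(Clone, Copy)]
struct Point { x: f32, y: f32 }

#[derive(Clone, Copy)]
struct Bounds { origin: Point, size: Point }

impl Bounds {
    // Half-open containment: the lower-right edge is excluded, so a point
    // on the border between two adjacent bounds belongs to exactly one.
    fn contains(&self, p: Point) -> bool {
        p.x >= self.origin.x
            && p.x < self.origin.x + self.size.x
            && p.y >= self.origin.y
            && p.y < self.origin.y + self.size.y
    }
}

fn main() {
    let a = Bounds { origin: Point { x: 0.0, y: 0.0 }, size: Point { x: 10.0, y: 10.0 } };
    let b = Bounds { origin: Point { x: 10.0, y: 0.0 }, size: Point { x: 10.0, y: 10.0 } };
    let edge = Point { x: 10.0, y: 5.0 };
    // Only `b` claims the shared edge, so hover fires for a single element.
    assert!(!a.contains(edge));
    assert!(b.contains(edge));
    println!("ok");
}
```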
localcc
636d11ebec Multiple priority scheduler (#44701)
Improves the scheduler by allowing tasks to be assigned a priority, which
significantly improves responsiveness.

Release notes:

- N/A

---------

Co-authored-by: Yara <git@yara.blue>
Co-authored-by: dvdsk <noreply@davidsk.dev>
2025-12-12 06:32:30 -08:00
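The priority-aware scheduling described above can be sketched with a max-heap keyed on priority, with a sequence number as a tie-breaker so equal-priority tasks run FIFO. This is a simplified illustration, not the actual gpui scheduler.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

// Declaration order gives Background < Normal < High under derived Ord.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
enum Priority { Background, Normal, High }

struct Scheduler {
    // Max-heap on (priority, Reverse(seq), task_id): highest priority first,
    // and FIFO among tasks of equal priority.
    queue: BinaryHeap<(Priority, Reverse<u64>, u64)>,
    next_seq: u64,
}

impl Scheduler {
    fn new() -> Self {
        Self { queue: BinaryHeap::new(), next_seq: 0 }
    }

    fn spawn(&mut self, priority: Priority, task_id: u64) {
        let seq = self.next_seq;
        self.next_seq += 1;
        self.queue.push((priority, Reverse(seq), task_id));
    }

    fn run_next(&mut self) -> Option<u64> {
        self.queue.pop().map(|(_, _, id)| id)
    }
}

fn main() {
    let mut sched = Scheduler::new();
    sched.spawn(Priority::Background, 1);
    sched.spawn(Priority::High, 2);
    sched.spawn(Priority::Normal, 3);
    sched.spawn(Priority::High, 4);
    // Latency-sensitive work runs first; background work runs last.
    let order: Vec<u64> = std::iter::from_fn(|| sched.run_next()).collect();
    assert_eq!(order, vec![2, 4, 3, 1]);
    println!("ok");
}
```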
Agus Zubiaga
4d0e760b04 edit prediction cli: Progress output cleanup (#44708)
- Limit status lines to 10 in case `max_parallelism` is specified with a
greater value
- Handle logging gracefully rather than writing over it when clearing
status lines

Release Notes:

- N/A
2025-12-12 14:03:08 +00:00
localcc
8bd4d866b9 Windows/send keystrokes (#44707)
Closes #41176 

Release Notes:

- Fixed SendKeystrokes mapping on Windows

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-12-12 05:51:11 -08:00
114 changed files with 5142 additions and 2557 deletions

22
Cargo.lock generated
View File

@@ -2770,9 +2770,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.41"
version = "1.2.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7"
checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
dependencies = [
"find-msvc-tools",
"jobserver",
@@ -3113,9 +3113,9 @@ dependencies = [
[[package]]
name = "cmake"
version = "0.1.54"
version = "0.1.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586"
dependencies = [
"cc",
]
@@ -5111,7 +5111,6 @@ dependencies = [
"cloud_llm_client",
"collections",
"copilot",
"credentials_provider",
"ctor",
"db",
"edit_prediction_context",
@@ -5201,7 +5200,6 @@ dependencies = [
"wasmtime",
"watch",
"zeta_prompt",
"zlog",
]
[[package]]
@@ -5276,7 +5274,6 @@ dependencies = [
"text",
"theme",
"ui",
"ui_input",
"util",
"workspace",
"zed_actions",
@@ -6094,9 +6091,9 @@ dependencies = [
[[package]]
name = "find-msvc-tools"
version = "0.1.4"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844"
[[package]]
name = "fixedbitset"
@@ -7240,6 +7237,7 @@ dependencies = [
"libc",
"log",
"lyon",
"mach2 0.5.0",
"media",
"metal",
"naga",
@@ -8802,6 +8800,7 @@ dependencies = [
"cloud_api_types",
"cloud_llm_client",
"collections",
"credentials_provider",
"futures 0.3.31",
"gpui",
"http_client",
@@ -8820,6 +8819,7 @@ dependencies = [
"telemetry_events",
"thiserror 2.0.17",
"util",
"zed_env_vars",
]
[[package]]
@@ -8876,7 +8876,6 @@ dependencies = [
"util",
"vercel",
"x_ai",
"zed_env_vars",
]
[[package]]
@@ -14778,6 +14777,8 @@ dependencies = [
"assets",
"bm25",
"client",
"copilot",
"edit_prediction",
"editor",
"feature_flags",
"fs",
@@ -14786,6 +14787,7 @@ dependencies = [
"gpui",
"heck 0.5.0",
"language",
"language_models",
"log",
"menu",
"node_runtime",

View File

@@ -39,6 +39,5 @@ Only make changes that are necessary to fulfill the prompt, leave everything els
Start at the indentation level in the original file in the rewritten {{content_type}}.
You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if
you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so.
IMPORTANT: You MUST use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. You MUST NOT send back unstructured text. If you need to make a statement or ask a question you MUST use one of the tools to do so.
It is an error if you try to make a change that cannot be made simply by editing the rewrite_section.

View File

@@ -896,6 +896,8 @@
"default_width": 380,
},
"agent": {
// Whether the inline assistant should use streaming tools, when available
"inline_assistant_use_streaming_tools": true,
// Whether the agent is enabled.
"enabled": true,
// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
@@ -1410,8 +1412,9 @@
"proxy_no_verify": null,
},
"codestral": {
"model": null,
"max_tokens": null,
"api_url": "https://codestral.mistral.ai",
"model": "codestral-latest",
"max_tokens": 150,
},
// Whether edit predictions are enabled when editing text threads in the agent panel.
// This setting has no effect if globally disabled.

View File

@@ -28,6 +28,7 @@ pub struct AgentSettings {
pub default_height: Pixels,
pub default_model: Option<LanguageModelSelection>,
pub inline_assistant_model: Option<LanguageModelSelection>,
pub inline_assistant_use_streaming_tools: bool,
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
@@ -155,6 +156,9 @@ impl Settings for AgentSettings {
default_height: px(agent.default_height.unwrap()),
default_model: Some(agent.default_model.unwrap()),
inline_assistant_model: agent.inline_assistant_model,
inline_assistant_use_streaming_tools: agent
.inline_assistant_use_streaming_tools
.unwrap_or(true),
commit_message_model: agent.commit_message_model,
thread_summary_model: agent.thread_summary_model,
inline_alternatives: agent.inline_alternatives.unwrap_or_default(),

View File

@@ -34,9 +34,9 @@ use project::{
};
use settings::{Settings, SettingsStore, update_settings_file};
use ui::{
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure, Divider,
DividerColor, ElevationIndex, Indicator, LabelSize, PopoverMenu, Switch, Tooltip,
WithScrollbar, prelude::*,
};
use util::ResultExt as _;
use workspace::{Workspace, create_and_open_local_file};

View File

@@ -445,6 +445,7 @@ mod tests {
default_height: px(600.),
default_model: None,
inline_assistant_model: None,
inline_assistant_use_streaming_tools: false,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: vec![],

View File

@@ -1,23 +1,26 @@
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{LocalBoxFuture, Shared},
join,
stream::BoxStream,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
report_assistant_event,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
LanguageModelToolUse, Role, TokenUsage, report_assistant_event,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
@@ -25,6 +28,7 @@ use prompt_store::PromptBuilder;
use rope::Rope;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings as _;
use smol::future::FutureExt;
use std::{
cmp,
@@ -46,6 +50,7 @@ pub struct FailureMessageInput {
/// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
///
/// The message may use markdown formatting if you wish.
#[serde(default)]
pub message: String,
}
@@ -56,9 +61,11 @@ pub struct RewriteSectionInput {
///
/// The description may use markdown formatting if you wish.
/// This is optional - if the edit is simple or obvious, you should leave it empty.
#[serde(default)]
pub description: String,
/// The text to replace the section with.
#[serde(default)]
pub replacement_text: String,
}
@@ -379,6 +386,12 @@ impl CodegenAlternative {
&self.last_equal_ranges
}
fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
model.supports_streaming_tools()
&& cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
&& AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
}
pub fn start(
&mut self,
user_prompt: String,
@@ -398,11 +411,17 @@ impl CodegenAlternative {
let telemetry_id = model.telemetry_id();
let provider_id = model.provider_id();
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
if Self::use_streaming_tools(model.as_ref(), cx) {
let request = self.build_request(&model, user_prompt, context_task, cx)?;
let tool_use =
cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
let completion_events =
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
self.generation = self.handle_completion(
telemetry_id,
provider_id.to_string(),
api_key,
completion_events,
cx,
);
} else {
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
if user_prompt.trim().to_lowercase() == "delete" {
@@ -414,13 +433,14 @@ impl CodegenAlternative {
})
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
self.generation =
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
}
Ok(())
}
fn build_request_v2(
fn build_request_tools(
&self,
model: &Arc<dyn LanguageModel>,
user_prompt: String,
@@ -456,7 +476,7 @@ impl CodegenAlternative {
let system_prompt = self
.builder
.generate_inline_transformation_prompt_v2(
.generate_inline_transformation_prompt_tools(
language_name,
buffer,
range.start.0..range.end.0,
@@ -466,6 +486,9 @@ impl CodegenAlternative {
let temperature = AgentSettings::temperature_for_model(model, cx);
let tool_input_format = model.tool_input_format();
let tool_choice = model
.supports_tool_choice(LanguageModelToolChoice::Any)
.then_some(LanguageModelToolChoice::Any);
Ok(cx.spawn(async move |_cx| {
let mut messages = vec![LanguageModelRequestMessage {
@@ -508,7 +531,7 @@ impl CodegenAlternative {
intent: Some(CompletionIntent::InlineAssist),
mode: None,
tools,
tool_choice: None,
tool_choice,
stop: Vec::new(),
temperature,
messages,
@@ -524,8 +547,8 @@ impl CodegenAlternative {
context_task: Shared<Task<Option<LoadedContext>>>,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
return self.build_request_v2(model, user_prompt, context_task, cx);
if Self::use_streaming_tools(model.as_ref(), cx) {
return self.build_request_tools(model, user_prompt, context_task, cx);
}
let buffer = self.buffer.read(cx).snapshot(cx);
@@ -603,7 +626,7 @@ impl CodegenAlternative {
model_api_key: Option<String>,
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let start_time = Instant::now();
// Make a new snapshot and re-resolve anchor in case the document was modified.
@@ -659,7 +682,8 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
self.generation = cx.spawn(async move |codegen, cx| {
cx.notify();
cx.spawn(async move |codegen, cx| {
let stream = stream.await;
let token_usage = stream
@@ -685,6 +709,7 @@ impl CodegenAlternative {
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -876,8 +901,7 @@ impl CodegenAlternative {
cx.notify();
})
.ok();
});
cx.notify();
})
}
pub fn current_completion(&self) -> Option<String> {
@@ -1060,21 +1084,29 @@ impl CodegenAlternative {
})
}
fn handle_tool_use(
fn handle_completion(
&mut self,
_telemetry_id: String,
_provider_id: String,
_api_key: Option<String>,
tool_use: impl 'static
+ Future<
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
telemetry_id: String,
provider_id: String,
api_key: Option<String>,
completion_stream: Task<
Result<
BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
>,
cx: &mut Context<Self>,
) {
) -> Task<()> {
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
self.generation = cx.spawn(async move |codegen, cx| {
cx.notify();
// Leaving this in generation so that STOP equivalent events are respected even
// while we're still pre-processing the completion event
cx.spawn(async move |codegen, cx| {
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
let _ = codegen.update(cx, |this, cx| {
this.status = status;
@@ -1083,76 +1115,176 @@ impl CodegenAlternative {
});
};
let tool_use = tool_use.await;
let mut completion_events = match completion_stream.await {
Ok(events) => events,
Err(err) => {
finish_with_status(CodegenStatus::Error(err.into()), cx);
return;
}
};
match tool_use {
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
// Parse the input JSON into RewriteSectionInput
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
Ok(input) => {
// Store the description if non-empty
let description = if !input.description.trim().is_empty() {
Some(input.description.clone())
} else {
None
let chars_read_so_far = Arc::new(Mutex::new(0usize));
let tool_to_text_and_message =
move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
let mut chars_read_so_far = chars_read_so_far.lock();
match tool_use.name.as_ref() {
"rewrite_section" => {
let Ok(mut input) =
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
else {
return (None, None);
};
let value = input.replacement_text[*chars_read_so_far..].to_string();
*chars_read_so_far = input.replacement_text.len();
(Some(value), Some(std::mem::take(&mut input.description)))
}
"failure_message" => {
let Ok(mut input) =
serde_json::from_value::<FailureMessageInput>(tool_use.input)
else {
return (None, None);
};
(None, Some(std::mem::take(&mut input.message)))
}
_ => (None, None),
}
};
// Apply the replacement text to the buffer and compute diff
let batch_diff_task = codegen
.update(cx, |this, cx| {
this.model_explanation = description.map(Into::into);
let range = this.range.clone();
this.apply_edits(
std::iter::once((range, input.replacement_text)),
cx,
);
this.reapply_batch_diff(cx)
})
.ok();
let mut message_id = None;
let mut first_text = None;
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
let total_text = Arc::new(Mutex::new(String::new()));
// Wait for the diff computation to complete
if let Some(diff_task) = batch_diff_task {
diff_task.await;
loop {
if let Some(first_event) = completion_events.next().await {
match first_event {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
message_id = Some(id);
}
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if matches!(
tool_use.name.as_ref(),
"rewrite_section" | "failure_message"
) =>
{
let is_complete = tool_use.is_input_complete;
let (text, message) = tool_to_text_and_message(tool_use);
// Only update the model explanation if the tool use is complete.
// Otherwise the UI element bounces around as it's updated.
if is_complete {
let _ = codegen.update(cx, |this, _cx| {
this.model_explanation = message.map(Into::into);
});
}
finish_with_status(CodegenStatus::Done, cx);
return;
first_text = text;
if first_text.is_some() {
break;
}
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
}
Ok(LanguageModelCompletionEvent::Text(text)) => {
let mut lock = total_text.lock();
lock.push_str(&text);
}
Ok(e) => {
log::warn!("Unexpected event: {:?}", e);
break;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
break;
}
}
}
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
// Handle failure message tool use
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
Ok(input) => {
let _ = codegen.update(cx, |this, _cx| {
// Store the failure message as the tool description
this.model_explanation = Some(input.message.into());
});
finish_with_status(CodegenStatus::Done, cx);
return;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
}
}
}
Ok(_tool_use) => {
// Unexpected tool.
finish_with_status(CodegenStatus::Done, cx);
return;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
}
}
});
cx.notify();
let Some(first_text) = first_text else {
finish_with_status(CodegenStatus::Done, cx);
return;
};
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
cx.spawn({
let codegen = codegen.clone();
async move |cx| {
while let Some(message) = message_rx.next().await {
let _ = codegen.update(cx, |this, _cx| {
this.model_explanation = message;
});
}
}
})
.detach();
let move_last_token_usage = last_token_usage.clone();
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
completion_events.filter_map(move |e| {
let tool_to_text_and_message = tool_to_text_and_message.clone();
let last_token_usage = move_last_token_usage.clone();
let total_text = total_text.clone();
let mut message_tx = message_tx.clone();
async move {
match e {
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if matches!(
tool_use.name.as_ref(),
"rewrite_section" | "failure_message"
) =>
{
let is_complete = tool_use.is_input_complete;
let (text, message) = tool_to_text_and_message(tool_use);
if is_complete {
// Again only send the message when complete to not get a bouncing UI element.
let _ = message_tx.send(message.map(Into::into)).await;
}
text.map(Ok)
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
}
Ok(LanguageModelCompletionEvent::Text(text)) => {
let mut lock = total_text.lock();
lock.push_str(&text);
None
}
Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
e => {
log::error!("UNEXPECTED EVENT {:?}", e);
None
}
}
}
}),
));
let language_model_text_stream = LanguageModelTextStream {
message_id: message_id,
stream: text_stream,
last_token_usage,
};
let Some(task) = codegen
.update(cx, move |codegen, cx| {
codegen.handle_stream(
telemetry_id,
provider_id,
api_key,
async { Ok(language_model_text_stream) },
cx,
)
})
.ok()
else {
return;
};
task.await;
})
}
}
@@ -1679,7 +1811,7 @@ mod tests {
) -> mpsc::UnboundedSender<String> {
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
codegen.generation = codegen.handle_stream(
String::new(),
String::new(),
None,

View File

@@ -1455,60 +1455,8 @@ impl InlineAssistant {
let old_snapshot = codegen.snapshot(cx);
let old_buffer = codegen.old_buffer(cx);
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
// let model_explanation = codegen.model_explanation(cx);
editor.update(cx, |editor, cx| {
// Update tool description block
// if let Some(description) = model_explanation {
// if let Some(block_id) = decorations.model_explanation {
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
// let new_block_id = editor.insert_blocks(
// [BlockProperties {
// style: BlockStyle::Flex,
// placement: BlockPlacement::Below(assist.range.end),
// height: Some(1),
// render: Arc::new({
// let description = description.clone();
// move |cx| {
// div()
// .w_full()
// .py_1()
// .px_2()
// .bg(cx.theme().colors().editor_background)
// .border_y_1()
// .border_color(cx.theme().status().info_border)
// .child(
// Label::new(description.clone())
// .color(Color::Muted)
// .size(LabelSize::Small),
// )
// .into_any_element()
// }
// }),
// priority: 0,
// }],
// None,
// cx,
// );
// decorations.model_explanation = new_block_id.into_iter().next();
// }
// } else if let Some(block_id) = decorations.model_explanation {
// // Hide the block if there's no description
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
// let new_block_id = editor.insert_blocks(
// [BlockProperties {
// style: BlockStyle::Flex,
// placement: BlockPlacement::Below(assist.range.end),
// height: Some(0),
// render: Arc::new(|_cx| div().into_any_element()),
// priority: 0,
// }],
// None,
// cx,
// );
// decorations.model_explanation = new_block_id.into_iter().next();
// }
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
editor.remove_blocks(old_blocks, None, cx);

View File

@@ -429,10 +429,24 @@ impl Model {
let mut headers = vec![];
match self {
Self::ClaudeOpus4
| Self::ClaudeOpus4_1
| Self::ClaudeOpus4_5
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4_5
| Self::ClaudeOpus4Thinking
| Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5Thinking
| Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5Thinking => {
// Fine-grained tool streaming for newer models
headers.push("fine-grained-tool-streaming-2025-05-14".to_string());
}
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => {
// Try beta token-efficient tool use (supported in Claude 3.7 Sonnet only)
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use
headers.push("token-efficient-tools-2025-02-19".to_string());
headers.push("fine-grained-tool-streaming-2025-05-14".to_string());
}
Self::Custom {
extra_beta_headers, ..

View File

@@ -371,6 +371,8 @@ pub struct LanguageModel {
pub supports_images: bool,
pub supports_thinking: bool,
pub supports_max_mode: bool,
#[serde(default)]
pub supports_streaming_tools: bool,
// only used by OpenAI and xAI
#[serde(default)]
pub supports_parallel_tool_calls: bool,

View File

@@ -4,7 +4,7 @@ pub mod copilot_responses;
pub mod request;
mod sign_in;
use crate::sign_in::initiate_sign_in_within_workspace;
use crate::sign_in::initiate_sign_out;
use ::fs::Fs;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
@@ -28,12 +28,10 @@ use project::DisableAiSettings;
use request::StatusNotification;
use semver::Version;
use serde_json::json;
use settings::Settings;
use settings::SettingsStore;
use sign_in::{reinstall_and_sign_in_within_workspace, sign_out_within_workspace};
use std::collections::hash_map::Entry;
use settings::{Settings, SettingsStore};
use std::{
any::TypeId,
collections::hash_map::Entry,
env,
ffi::OsString,
mem,
@@ -42,12 +40,14 @@ use std::{
sync::Arc,
};
use sum_tree::Dimensions;
use util::rel_path::RelPath;
use util::{ResultExt, fs::remove_matching};
use util::{ResultExt, fs::remove_matching, rel_path::RelPath};
use workspace::Workspace;
pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
pub use crate::sign_in::{
ConfigurationMode, ConfigurationView, CopilotCodeVerification, initiate_sign_in,
reinstall_and_sign_in,
};
actions!(
copilot,
@@ -98,21 +98,14 @@ pub fn init(
.detach();
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
workspace.register_action(|workspace, _: &SignIn, window, cx| {
if let Some(copilot) = Copilot::global(cx) {
let is_reinstall = false;
initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
}
workspace.register_action(|_, _: &SignIn, window, cx| {
initiate_sign_in(window, cx);
});
workspace.register_action(|workspace, _: &Reinstall, window, cx| {
if let Some(copilot) = Copilot::global(cx) {
reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
}
workspace.register_action(|_, _: &Reinstall, window, cx| {
reinstall_and_sign_in(window, cx);
});
workspace.register_action(|workspace, _: &SignOut, _window, cx| {
if let Some(copilot) = Copilot::global(cx) {
sign_out_within_workspace(workspace, copilot, cx);
}
workspace.register_action(|_, _: &SignOut, window, cx| {
initiate_sign_out(window, cx);
});
})
.detach();
@@ -375,7 +368,7 @@ impl Copilot {
}
}
fn start_copilot(
pub fn start_copilot(
&mut self,
check_edit_prediction_provider: bool,
awaiting_sign_in_after_start: bool,
@@ -563,6 +556,14 @@ impl Copilot {
let server = start_language_server.await;
this.update(cx, |this, cx| {
cx.notify();
if env::var("ZED_FORCE_COPILOT_ERROR").is_ok() {
this.server = CopilotServer::Error(
"Forced error for testing (ZED_FORCE_COPILOT_ERROR)".into(),
);
return;
}
match server {
Ok((server, status)) => {
this.server = CopilotServer::Running(RunningCopilotServer {
@@ -584,7 +585,17 @@ impl Copilot {
.ok();
}
pub(crate) fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
pub fn is_authenticated(&self) -> bool {
return matches!(
self.server,
CopilotServer::Running(RunningCopilotServer {
sign_in_status: SignInStatus::Authorized,
..
})
);
}
pub fn sign_in(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
if let CopilotServer::Running(server) = &mut self.server {
let task = match &server.sign_in_status {
SignInStatus::Authorized => Task::ready(Ok(())).shared(),


@@ -1,160 +1,151 @@
use crate::{Copilot, Status, request::PromptUserDeviceFlow};
use anyhow::Context as _;
use gpui::{
Animation, AnimationExt, App, ClipboardItem, Context, DismissEvent, Element, Entity,
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent,
ParentElement, Render, Styled, Subscription, Transformation, Window, div, percentage, svg,
App, ClipboardItem, Context, DismissEvent, Element, Entity, EventEmitter, FocusHandle,
Focusable, InteractiveElement, IntoElement, MouseDownEvent, ParentElement, Render, Styled,
Subscription, Window, WindowBounds, WindowOptions, div, point,
};
use std::time::Duration;
use ui::{Button, Label, Vector, VectorName, prelude::*};
use ui::{ButtonLike, CommonAnimationExt, ConfiguredApiCard, Vector, VectorName, prelude::*};
use util::ResultExt as _;
use workspace::notifications::NotificationId;
use workspace::{ModalView, Toast, Workspace};
use workspace::{Toast, Workspace, notifications::NotificationId};
const COPILOT_SIGN_UP_URL: &str = "https://github.com/features/copilot";
const ERROR_LABEL: &str =
"Copilot had issues starting. You can try reinstalling it and signing in again.";
struct CopilotStatusToast;
pub fn initiate_sign_in(window: &mut Window, cx: &mut App) {
let is_reinstall = false;
initiate_sign_in_impl(is_reinstall, window, cx)
}
pub fn initiate_sign_out(window: &mut Window, cx: &mut App) {
let Some(copilot) = Copilot::global(cx) else {
return;
};
let Some(workspace) = window.root::<Workspace>().flatten() else {
return;
};
workspace.update(cx, |workspace, cx| {
let is_reinstall = false;
initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx)
});
copilot_toast(Some("Signing out of Copilot…"), window, cx);
let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx));
window
.spawn(cx, async move |cx| match sign_out_task.await {
Ok(()) => {
cx.update(|window, cx| copilot_toast(Some("Signed out of Copilot"), window, cx))
}
Err(err) => cx.update(|window, cx| {
if let Some(workspace) = window.root::<Workspace>().flatten() {
workspace.update(cx, |workspace, cx| {
workspace.show_error(&err, cx);
})
} else {
log::error!("{:?}", err);
}
}),
})
.detach();
}
pub fn reinstall_and_sign_in(window: &mut Window, cx: &mut App) {
let Some(copilot) = Copilot::global(cx) else {
return;
};
let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx));
let is_reinstall = true;
initiate_sign_in_impl(is_reinstall, window, cx);
}
fn open_copilot_code_verification_window(copilot: &Entity<Copilot>, window: &Window, cx: &mut App) {
let current_window_center = window.bounds().center();
let width = px(350.);
let height = px(450.);
let window_bounds = WindowBounds::Windowed(gpui::bounds(
current_window_center - point(width / 2.0, height / 2.0),
gpui::size(width, height),
));
cx.open_window(
WindowOptions {
kind: gpui::WindowKind::PopUp,
window_bounds: Some(window_bounds),
is_resizable: false,
is_movable: true,
titlebar: Some(gpui::TitlebarOptions {
appears_transparent: true,
..Default::default()
}),
..Default::default()
},
|window, cx| cx.new(|cx| CopilotCodeVerification::new(&copilot, window, cx)),
)
.context("Failed to open Copilot code verification window")
.log_err();
}
fn copilot_toast(message: Option<&'static str>, window: &Window, cx: &mut App) {
const NOTIFICATION_ID: NotificationId = NotificationId::unique::<CopilotStatusToast>();
let Some(workspace) = window.root::<Workspace>().flatten() else {
return;
};
workspace.update(cx, |workspace, cx| {
reinstall_and_sign_in_within_workspace(workspace, copilot, window, cx);
workspace.update(cx, |workspace, cx| match message {
Some(message) => workspace.show_toast(Toast::new(NOTIFICATION_ID, message), cx),
None => workspace.dismiss_toast(&NOTIFICATION_ID, cx),
});
}
pub fn reinstall_and_sign_in_within_workspace(
workspace: &mut Workspace,
copilot: Entity<Copilot>,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
let _ = copilot.update(cx, |copilot, cx| copilot.reinstall(cx));
let is_reinstall = true;
initiate_sign_in_within_workspace(workspace, copilot, is_reinstall, window, cx);
}
pub fn initiate_sign_in_within_workspace(
workspace: &mut Workspace,
copilot: Entity<Copilot>,
is_reinstall: bool,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
pub fn initiate_sign_in_impl(is_reinstall: bool, window: &mut Window, cx: &mut App) {
let Some(copilot) = Copilot::global(cx) else {
return;
};
if matches!(copilot.read(cx).status(), Status::Disabled) {
copilot.update(cx, |copilot, cx| copilot.start_copilot(false, true, cx));
}
match copilot.read(cx).status() {
Status::Starting { task } => {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStatusToast>(),
if is_reinstall {
"Copilot is reinstalling..."
} else {
"Copilot is starting..."
},
),
copilot_toast(
Some(if is_reinstall {
"Copilot is reinstalling…"
} else {
"Copilot is starting…"
}),
window,
cx,
);
cx.spawn_in(window, async move |workspace, cx| {
task.await;
if let Some(copilot) = cx.update(|_window, cx| Copilot::global(cx)).ok().flatten() {
workspace
.update_in(cx, |workspace, window, cx| {
match copilot.read(cx).status() {
Status::Authorized => workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStatusToast>(),
"Copilot has started.",
),
cx,
),
_ => {
workspace.dismiss_toast(
&NotificationId::unique::<CopilotStatusToast>(),
cx,
);
copilot
.update(cx, |copilot, cx| copilot.sign_in(cx))
.detach_and_log_err(cx);
workspace.toggle_modal(window, cx, |_, cx| {
CopilotCodeVerification::new(&copilot, cx)
});
}
window
.spawn(cx, async move |cx| {
task.await;
cx.update(|window, cx| {
let Some(copilot) = Copilot::global(cx) else {
return;
};
match copilot.read(cx).status() {
Status::Authorized => {
copilot_toast(Some("Copilot has started."), window, cx)
}
})
.log_err();
}
})
.detach();
_ => {
copilot_toast(None, window, cx);
copilot
.update(cx, |copilot, cx| copilot.sign_in(cx))
.detach_and_log_err(cx);
open_copilot_code_verification_window(&copilot, window, cx);
}
}
})
.log_err();
})
.detach();
}
_ => {
copilot
.update(cx, |copilot, cx| copilot.sign_in(cx))
.detach();
workspace.toggle_modal(window, cx, |_, cx| {
CopilotCodeVerification::new(&copilot, cx)
});
open_copilot_code_verification_window(&copilot, window, cx);
}
}
}
pub fn sign_out_within_workspace(
workspace: &mut Workspace,
copilot: Entity<Copilot>,
cx: &mut Context<Workspace>,
) {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStatusToast>(),
"Signing out of Copilot...",
),
cx,
);
let sign_out_task = copilot.update(cx, |copilot, cx| copilot.sign_out(cx));
cx.spawn(async move |workspace, cx| match sign_out_task.await {
Ok(()) => {
workspace
.update(cx, |workspace, cx| {
workspace.show_toast(
Toast::new(
NotificationId::unique::<CopilotStatusToast>(),
"Signed out of Copilot.",
),
cx,
)
})
.ok();
}
Err(err) => {
workspace
.update(cx, |workspace, cx| {
workspace.show_error(&err, cx);
})
.ok();
}
})
.detach();
}
pub struct CopilotCodeVerification {
status: Status,
connect_clicked: bool,
@@ -170,23 +161,27 @@ impl Focusable for CopilotCodeVerification {
}
impl EventEmitter<DismissEvent> for CopilotCodeVerification {}
impl ModalView for CopilotCodeVerification {
fn on_before_dismiss(
&mut self,
_: &mut Window,
cx: &mut Context<Self>,
) -> workspace::DismissDecision {
self.copilot.update(cx, |copilot, cx| {
if matches!(copilot.status(), Status::SigningIn { .. }) {
copilot.sign_out(cx).detach_and_log_err(cx);
}
});
workspace::DismissDecision::Dismiss(true)
}
}
impl CopilotCodeVerification {
pub fn new(copilot: &Entity<Copilot>, cx: &mut Context<Self>) -> Self {
pub fn new(copilot: &Entity<Copilot>, window: &mut Window, cx: &mut Context<Self>) -> Self {
window.on_window_should_close(cx, |window, cx| {
if let Some(this) = window.root::<CopilotCodeVerification>().flatten() {
this.update(cx, |this, cx| {
this.before_dismiss(cx);
});
}
true
});
cx.subscribe_in(
&cx.entity(),
window,
|this, _, _: &DismissEvent, window, cx| {
window.remove_window();
this.before_dismiss(cx);
},
)
.detach();
let status = copilot.read(cx).status();
Self {
status,
@@ -215,45 +210,45 @@ impl CopilotCodeVerification {
.read_from_clipboard()
.map(|item| item.text().as_ref() == Some(&data.user_code))
.unwrap_or(false);
h_flex()
.w_full()
.p_1()
.border_1()
.border_muted(cx)
.rounded_sm()
.cursor_pointer()
.justify_between()
.on_mouse_down(gpui::MouseButton::Left, {
ButtonLike::new("copy-button")
.full_width()
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.size(ButtonSize::Medium)
.child(
h_flex()
.w_full()
.p_1()
.justify_between()
.child(Label::new(data.user_code.clone()))
.child(Label::new(if copied { "Copied!" } else { "Copy" })),
)
.on_click({
let user_code = data.user_code.clone();
move |_, window, cx| {
cx.write_to_clipboard(ClipboardItem::new_string(user_code.clone()));
window.refresh();
}
})
.child(div().flex_1().child(Label::new(data.user_code.clone())))
.child(div().flex_none().px_1().child(Label::new(if copied {
"Copied!"
} else {
"Copy"
})))
}
fn render_prompting_modal(
connect_clicked: bool,
data: &PromptUserDeviceFlow,
cx: &mut Context<Self>,
) -> impl Element {
let connect_button_label = if connect_clicked {
"Waiting for connection..."
"Waiting for connection…"
} else {
"Connect to GitHub"
};
v_flex()
.flex_1()
.gap_2()
.gap_2p5()
.items_center()
.child(Headline::new("Use GitHub Copilot in Zed.").size(HeadlineSize::Large))
.text_center()
.child(Headline::new("Use GitHub Copilot in Zed").size(HeadlineSize::Large))
.child(
Label::new("Using Copilot requires an active subscription on GitHub.")
.color(Color::Muted),
@@ -261,83 +256,119 @@ impl CopilotCodeVerification {
.child(Self::render_device_code(data, cx))
.child(
Label::new("Paste this code into GitHub after clicking the button below.")
.size(ui::LabelSize::Small),
.color(Color::Muted),
)
.child(
Button::new("connect-button", connect_button_label)
.on_click({
let verification_uri = data.verification_uri.clone();
cx.listener(move |this, _, _window, cx| {
cx.open_url(&verification_uri);
this.connect_clicked = true;
})
})
.full_width()
.style(ButtonStyle::Filled),
)
.child(
Button::new("copilot-enable-cancel-button", "Cancel")
.full_width()
.on_click(cx.listener(|_, _, _, cx| {
cx.emit(DismissEvent);
})),
v_flex()
.w_full()
.gap_1()
.child(
Button::new("connect-button", connect_button_label)
.full_width()
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
.on_click({
let verification_uri = data.verification_uri.clone();
cx.listener(move |this, _, _window, cx| {
cx.open_url(&verification_uri);
this.connect_clicked = true;
})
}),
)
.child(
Button::new("copilot-enable-cancel-button", "Cancel")
.full_width()
.size(ButtonSize::Medium)
.on_click(cx.listener(|_, _, _, cx| {
cx.emit(DismissEvent);
})),
),
)
}
fn render_enabled_modal(cx: &mut Context<Self>) -> impl Element {
v_flex()
.gap_2()
.text_center()
.justify_center()
.child(Headline::new("Copilot Enabled!").size(HeadlineSize::Large))
.child(Label::new(
"You can update your settings or sign out from the Copilot menu in the status bar.",
))
.child(Label::new("You're all set to use GitHub Copilot.").color(Color::Muted))
.child(
Button::new("copilot-enabled-done-button", "Done")
.full_width()
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
)
}
fn render_unauthorized_modal(cx: &mut Context<Self>) -> impl Element {
v_flex()
.child(Headline::new("You must have an active GitHub Copilot subscription.").size(HeadlineSize::Large))
let description = "Enable Copilot by connecting your existing license once you have subscribed or renewed your subscription.";
.child(Label::new(
"You can enable Copilot by connecting your existing license once you have subscribed or renewed your subscription.",
).color(Color::Warning))
v_flex()
.gap_2()
.text_center()
.justify_center()
.child(
Headline::new("You must have an active GitHub Copilot subscription.")
.size(HeadlineSize::Large),
)
.child(Label::new(description).color(Color::Warning))
.child(
Button::new("copilot-subscribe-button", "Subscribe on GitHub")
.full_width()
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
.on_click(|_, _, cx| cx.open_url(COPILOT_SIGN_UP_URL)),
)
.child(
Button::new("copilot-subscribe-cancel-button", "Cancel")
.full_width()
.size(ButtonSize::Medium)
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
)
}
fn render_loading(window: &mut Window, _: &mut Context<Self>) -> impl Element {
let loading_icon = svg()
.size_8()
.path(IconName::ArrowCircle.path())
.text_color(window.text_style().color)
.with_animation(
"icon_circle_arrow",
Animation::new(Duration::from_secs(2)).repeat(),
|svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
);
fn render_error_modal(_cx: &mut Context<Self>) -> impl Element {
v_flex()
.gap_2()
.text_center()
.justify_center()
.child(Headline::new("An Error Occurred").size(HeadlineSize::Large))
.child(Label::new(ERROR_LABEL).color(Color::Muted))
.child(
Button::new("copilot-subscribe-button", "Reinstall Copilot and Sign In")
.full_width()
.style(ButtonStyle::Outlined)
.size(ButtonSize::Medium)
.icon(IconName::Download)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.on_click(|_, window, cx| reinstall_and_sign_in(window, cx)),
)
}
h_flex().justify_center().child(loading_icon)
fn before_dismiss(
&mut self,
cx: &mut Context<'_, CopilotCodeVerification>,
) -> workspace::DismissDecision {
self.copilot.update(cx, |copilot, cx| {
if matches!(copilot.status(), Status::SigningIn { .. }) {
copilot.sign_out(cx).detach_and_log_err(cx);
}
});
workspace::DismissDecision::Dismiss(true)
}
}
impl Render for CopilotCodeVerification {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let prompt = match &self.status {
Status::SigningIn { prompt: None } => {
Self::render_loading(window, cx).into_any_element()
}
Status::SigningIn { prompt: None } => Icon::new(IconName::ArrowCircle)
.color(Color::Muted)
.with_rotate_animation(2)
.into_any_element(),
Status::SigningIn {
prompt: Some(prompt),
} => Self::render_prompting_modal(self.connect_clicked, prompt, cx).into_any_element(),
@@ -349,17 +380,20 @@ impl Render for CopilotCodeVerification {
self.connect_clicked = false;
Self::render_enabled_modal(cx).into_any_element()
}
Status::Error(..) => Self::render_error_modal(cx).into_any_element(),
_ => div().into_any_element(),
};
v_flex()
.id("copilot code verification")
.id("copilot_code_verification")
.track_focus(&self.focus_handle(cx))
.elevation_3(cx)
.w_96()
.items_center()
.p_4()
.size_full()
.px_4()
.py_8()
.gap_2()
.items_center()
.justify_center()
.elevation_3(cx)
.on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
cx.emit(DismissEvent);
}))
@@ -373,3 +407,243 @@ impl Render for CopilotCodeVerification {
.child(prompt)
}
}
pub struct ConfigurationView {
copilot_status: Option<Status>,
is_authenticated: fn(cx: &App) -> bool,
edit_prediction: bool,
_subscription: Option<Subscription>,
}
pub enum ConfigurationMode {
Chat,
EditPrediction,
}
impl ConfigurationView {
pub fn new(
is_authenticated: fn(cx: &App) -> bool,
mode: ConfigurationMode,
cx: &mut Context<Self>,
) -> Self {
let copilot = Copilot::global(cx);
Self {
copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
is_authenticated,
edit_prediction: matches!(mode, ConfigurationMode::EditPrediction),
_subscription: copilot.as_ref().map(|copilot| {
cx.observe(copilot, |this, model, cx| {
this.copilot_status = Some(model.read(cx).status());
cx.notify();
})
}),
}
}
}
impl ConfigurationView {
fn is_starting(&self) -> bool {
matches!(&self.copilot_status, Some(Status::Starting { .. }))
}
fn is_signing_in(&self) -> bool {
matches!(
&self.copilot_status,
Some(Status::SigningIn { .. })
| Some(Status::SignedOut {
awaiting_signing_in: true
})
)
}
fn is_error(&self) -> bool {
matches!(&self.copilot_status, Some(Status::Error(_)))
}
fn has_no_status(&self) -> bool {
self.copilot_status.is_none()
}
fn loading_message(&self) -> Option<SharedString> {
if self.is_starting() {
Some("Starting Copilot…".into())
} else if self.is_signing_in() {
Some("Signing into Copilot…".into())
} else {
None
}
}
fn render_loading_button(
&self,
label: impl Into<SharedString>,
edit_prediction: bool,
) -> impl IntoElement {
ButtonLike::new("loading_button")
.disabled(true)
.style(ButtonStyle::Outlined)
.when(edit_prediction, |this| this.size(ButtonSize::Medium))
.child(
h_flex()
.w_full()
.gap_1()
.justify_center()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Muted)
.with_rotate_animation(4),
)
.child(Label::new(label)),
)
}
fn render_sign_in_button(&self, edit_prediction: bool) -> impl IntoElement {
let label = if edit_prediction {
"Sign in to GitHub"
} else {
"Sign in to use GitHub Copilot"
};
Button::new("sign_in", label)
.map(|this| {
if edit_prediction {
this.size(ButtonSize::Medium)
} else {
this.full_width()
}
})
.style(ButtonStyle::Outlined)
.icon(IconName::Github)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.on_click(|_, window, cx| initiate_sign_in(window, cx))
}
fn render_reinstall_button(&self, edit_prediction: bool) -> impl IntoElement {
let label = if edit_prediction {
"Reinstall and Sign in"
} else {
"Reinstall Copilot and Sign in"
};
Button::new("reinstall_and_sign_in", label)
.map(|this| {
if edit_prediction {
this.size(ButtonSize::Medium)
} else {
this.full_width()
}
})
.style(ButtonStyle::Outlined)
.icon(IconName::Download)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.on_click(|_, window, cx| reinstall_and_sign_in(window, cx))
}
fn render_for_edit_prediction(&self) -> impl IntoElement {
let container = |description: SharedString, action: AnyElement| {
h_flex()
.pt_2p5()
.w_full()
.justify_between()
.child(
v_flex()
.w_full()
.max_w_1_2()
.child(Label::new("Authenticate To Use"))
.child(
Label::new(description)
.color(Color::Muted)
.size(LabelSize::Small),
),
)
.child(action)
};
let start_label = "To use Copilot for edit predictions, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot subscription.".into();
let no_status_label = "Copilot requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different edit predictions provider.".into();
if let Some(msg) = self.loading_message() {
container(
start_label,
self.render_loading_button(msg, true).into_any_element(),
)
.into_any_element()
} else if self.is_error() {
container(
ERROR_LABEL.into(),
self.render_reinstall_button(true).into_any_element(),
)
.into_any_element()
} else if self.has_no_status() {
container(
no_status_label,
self.render_sign_in_button(true).into_any_element(),
)
.into_any_element()
} else {
container(
start_label,
self.render_sign_in_button(true).into_any_element(),
)
.into_any_element()
}
}
fn render_for_chat(&self) -> impl IntoElement {
let start_label = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
let no_status_label = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different LLM provider.";
if let Some(msg) = self.loading_message() {
v_flex()
.gap_2()
.child(Label::new(start_label))
.child(self.render_loading_button(msg, false))
.into_any_element()
} else if self.is_error() {
v_flex()
.gap_2()
.child(Label::new(ERROR_LABEL))
.child(self.render_reinstall_button(false))
.into_any_element()
} else if self.has_no_status() {
v_flex()
.gap_2()
.child(Label::new(no_status_label))
.child(self.render_sign_in_button(false))
.into_any_element()
} else {
v_flex()
.gap_2()
.child(Label::new(start_label))
.child(self.render_sign_in_button(false))
.into_any_element()
}
}
}
impl Render for ConfigurationView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_authenticated = self.is_authenticated;
if is_authenticated(cx) {
return ConfiguredApiCard::new("Authorized")
.button_label("Sign Out")
.on_click(|_, window, cx| {
initiate_sign_out(window, cx);
})
.into_any_element();
}
if self.edit_prediction {
self.render_for_edit_prediction().into_any_element()
} else {
self.render_for_chat().into_any_element()
}
}
}


@@ -23,7 +23,6 @@ client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
db.workspace = true
edit_prediction_types.workspace = true
edit_prediction_context.workspace = true


@@ -72,6 +72,7 @@ pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use language_model::ApiKeyState;
pub use telemetry_events::EditPredictionRating;
pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
@@ -536,22 +537,12 @@ impl EditPredictionStore {
self.edit_prediction_model = model;
}
pub fn has_sweep_api_token(&self) -> bool {
self.sweep_ai
.api_token
.clone()
.now_or_never()
.flatten()
.is_some()
pub fn has_sweep_api_token(&self, cx: &App) -> bool {
self.sweep_ai.api_token.read(cx).has_key()
}
pub fn has_mercury_api_token(&self) -> bool {
self.mercury
.api_token
.clone()
.now_or_never()
.flatten()
.is_some()
pub fn has_mercury_api_token(&self, cx: &App) -> bool {
self.mercury.api_token.read(cx).has_key()
}
#[cfg(feature = "cli-support")]
@@ -586,10 +577,11 @@ impl EditPredictionStore {
pub fn edit_history_for_project(
&self,
project: &Entity<Project>,
cx: &App,
) -> Vec<Arc<zeta_prompt::Event>> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events.iter().cloned().collect())
.map(|project_state| project_state.events(cx))
.unwrap_or_default()
}
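The lookup in `edit_history_for_project` above follows a common get-or-empty shape: clone the per-project data when the key is known, otherwise return a default. A sketch with plain std types (the `u64` project id and `String` events are placeholders for the real entity id and event type):

```rust
use std::collections::HashMap;

// Mirrors edit_history_for_project: clone the stored events for a known
// project, or return an empty Vec when the project has no state yet.
fn events_for(projects: &HashMap<u64, Vec<String>>, project_id: u64) -> Vec<String> {
    projects
        .get(&project_id)
        .map(|events| events.clone())
        .unwrap_or_default()
}

fn main() {
    let mut projects = HashMap::new();
    projects.insert(7, vec!["edit-a".to_string(), "edit-b".to_string()]);
    assert_eq!(events_for(&projects, 7), vec!["edit-a", "edit-b"]);
    // Unknown projects yield an empty history rather than an error.
    assert!(events_for(&projects, 99).is_empty());
}
```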


@@ -1,40 +1,34 @@
use anyhow::{Context as _, Result};
use credentials_provider::CredentialsProvider;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
use crate::{
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use futures::AsyncReadExt as _;
use gpui::{
App, AppContext as _, Entity, SharedString, Task,
http_client::{self, AsyncBody, Method},
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use language_model::{ApiKeyState, EnvVar, env_var};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
pub struct Mercury {
pub api_token: Shared<Task<Option<String>>>,
pub api_token: Entity<ApiKeyState>,
}
impl Mercury {
pub fn new(cx: &App) -> Self {
pub fn new(cx: &mut App) -> Self {
Mercury {
api_token: load_api_token(cx).shared(),
api_token: mercury_api_token(cx),
}
}
pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
self.api_token = Task::ready(api_token.clone()).shared();
store_api_token_in_keychain(api_token, cx)
}
pub(crate) fn request_prediction(
&self,
EditPredictionModelInput {
@@ -48,7 +42,10 @@ impl Mercury {
}: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
});
let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
@@ -299,45 +296,16 @@ fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(
prompt.push_str(delimiters.end);
}
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
pub const MERCURY_CREDENTIALS_URL: SharedString =
SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
pub static MERCURY_API_KEY: std::sync::OnceLock<Entity<ApiKeyState>> = std::sync::OnceLock::new();
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
.ok()
.filter(|value| !value.is_empty())
{
return Task::ready(Some(api_token));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
let (_, credentials) = credentials_provider
.read_credentials(MERCURY_CREDENTIALS_URL, &cx)
.await
.ok()??;
String::from_utf8(credentials).ok()
})
}
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
if let Some(api_token) = api_token {
credentials_provider
.write_credentials(
MERCURY_CREDENTIALS_URL,
MERCURY_CREDENTIALS_USERNAME,
api_token.as_bytes(),
cx,
)
.await
.context("Failed to save Mercury API token to system keychain")
} else {
credentials_provider
.delete_credentials(MERCURY_CREDENTIALS_URL, cx)
.await
.context("Failed to delete Mercury API token from system keychain")
}
})
pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
MERCURY_API_KEY
.get_or_init(|| {
cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()))
})
.clone()
}
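The `OnceLock`-backed global in `mercury_api_token` above is a standard lazy-singleton shape: initialize once, then every caller shares the same handle. A dependency-free sketch of that shape, with a plain `String` standing in for `Entity<ApiKeyState>`:

```rust
use std::sync::OnceLock;

static API_KEY: OnceLock<String> = OnceLock::new();

// Mirrors mercury_api_token: run the initializer only on first access,
// then hand out clones of the same stored value to every caller.
fn api_token(init: impl FnOnce() -> String) -> String {
    API_KEY.get_or_init(init).clone()
}

fn main() {
    let first = api_token(|| "token-from-keychain".to_string());
    // A later initializer is ignored: the first stored value wins.
    let second = api_token(|| "different-token".to_string());
    assert_eq!(first, second);
}
```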


@@ -1,11 +1,11 @@
use anyhow::{Context as _, Result};
use credentials_provider::CredentialsProvider;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use anyhow::Result;
use futures::AsyncReadExt as _;
use gpui::{
App, AppContext as _, Task,
App, AppContext as _, Entity, SharedString, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Point, ToOffset as _};
use language_model::{ApiKeyState, EnvVar, env_var};
use lsp::DiagnosticSeverity;
use serde::{Deserialize, Serialize};
use std::{
@@ -20,30 +20,28 @@ use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredicti
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
pub struct SweepAi {
pub api_token: Shared<Task<Option<String>>>,
pub api_token: Entity<ApiKeyState>,
pub debug_info: Arc<str>,
}
impl SweepAi {
pub fn new(cx: &App) -> Self {
pub fn new(cx: &mut App) -> Self {
SweepAi {
api_token: load_api_token(cx).shared(),
api_token: sweep_api_token(cx),
debug_info: debug_info(cx),
}
}
pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
self.api_token = Task::ready(api_token.clone()).shared();
store_api_token_in_keychain(api_token, cx)
}
pub fn request_prediction_with_sweep(
&self,
inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
self.api_token.update(cx, |key_state, cx| {
_ = key_state.load_if_needed(SWEEP_CREDENTIALS_URL, |s| s, cx);
});
let Some(api_token) = self.api_token.read(cx).key(&SWEEP_CREDENTIALS_URL) else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = inputs
@@ -270,47 +268,18 @@ impl SweepAi {
}
}
pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
pub const SWEEP_CREDENTIALS_URL: SharedString =
SharedString::new_static("https://autocomplete.sweep.dev");
pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
pub static SWEEP_AI_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("SWEEP_AI_TOKEN");
pub static SWEEP_API_KEY: std::sync::OnceLock<Entity<ApiKeyState>> = std::sync::OnceLock::new();
pub fn load_api_token(cx: &App) -> Task<Option<String>> {
if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
.ok()
.filter(|value| !value.is_empty())
{
return Task::ready(Some(api_token));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
let (_, credentials) = credentials_provider
.read_credentials(SWEEP_CREDENTIALS_URL, &cx)
.await
.ok()??;
String::from_utf8(credentials).ok()
})
}
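The removed `load_api_token` above encodes a lookup order: a set, non-empty `SWEEP_AI_TOKEN` environment variable wins, otherwise the stored credential is used. That ordering can be sketched without any keychain or async dependency (the `Option<String>` inputs stand in for the env var and the keychain read):

```rust
// Mirrors the env-var-first lookup in load_api_token: a set, non-empty
// environment value takes precedence over a previously stored credential.
fn resolve_token(env_value: Option<String>, stored: Option<String>) -> Option<String> {
    env_value.filter(|value| !value.is_empty()).or(stored)
}

fn main() {
    let keychain = Some("from-keychain".to_string());
    assert_eq!(
        resolve_token(Some("from-env".to_string()), keychain.clone()),
        Some("from-env".to_string())
    );
    // An empty environment variable counts as unset.
    assert_eq!(resolve_token(Some(String::new()), keychain.clone()), keychain);
    assert_eq!(resolve_token(None, None), None);
}
```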
fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
cx.spawn(async move |cx| {
if let Some(api_token) = api_token {
credentials_provider
.write_credentials(
SWEEP_CREDENTIALS_URL,
SWEEP_CREDENTIALS_USERNAME,
api_token.as_bytes(),
cx,
)
.await
.context("Failed to save Sweep API token to system keychain")
} else {
credentials_provider
.delete_credentials(SWEEP_CREDENTIALS_URL, cx)
.await
.context("Failed to delete Sweep API token from system keychain")
}
})
pub fn sweep_api_token(cx: &mut App) -> Entity<ApiKeyState> {
SWEEP_API_KEY
.get_or_init(|| {
cx.new(|_| ApiKeyState::new(SWEEP_CREDENTIALS_URL, SWEEP_AI_TOKEN_ENV_VAR.clone()))
})
.clone()
}
#[derive(Debug, Clone, Serialize)]


@@ -100,7 +100,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
) -> bool {
let store = self.store.read(cx);
if store.edit_prediction_model == EditPredictionModel::Sweep {
store.has_sweep_api_token()
store.has_sweep_api_token(cx)
} else {
true
}


@@ -228,13 +228,16 @@ pub fn zeta2_prompt_input(
}
#[cfg(feature = "cli-support")]
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String {
eprintln!("{}", patch);
eprintln!("---------------------");
eprintln!("{}", input.cursor_excerpt);
crate::udiff::apply_diff_to_string(
patch,
&input.cursor_excerpt[input.editable_range_in_excerpt.clone()],
)
.unwrap()
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
let text = &input.cursor_excerpt;
let editable_region = input.editable_range_in_excerpt.clone();
let old_prefix = &text[..editable_region.start];
let old_suffix = &text[editable_region.end..];
let new = crate::udiff::apply_diff_to_string(patch, text)?;
if !new.starts_with(old_prefix) || !new.ends_with(old_suffix) {
anyhow::bail!("Patch shouldn't affect text outside of editable region");
}
Ok(new[editable_region.start..new.len() - old_suffix.len()].to_string())
}
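The new `zeta2_output_for_patch` adds a guard: after applying the patch to the whole excerpt, the result is rejected unless everything outside the editable byte range is byte-for-byte unchanged, and only the edited slice is returned. A minimal sketch of just that guard (the diff application itself is elided; `new` is the already-patched text, and the offsets are byte offsets):

```rust
use std::ops::Range;

// Mirrors the prefix/suffix check in zeta2_output_for_patch: reject any
// result that touches text outside `editable`, else return the new slice.
fn extract_editable(old: &str, new: &str, editable: Range<usize>) -> Result<String, String> {
    let old_prefix = &old[..editable.start];
    let old_suffix = &old[editable.end..];
    if !new.starts_with(old_prefix) || !new.ends_with(old_suffix) {
        return Err("patch affected text outside the editable region".to_string());
    }
    Ok(new[editable.start..new.len() - old_suffix.len()].to_string())
}

fn main() {
    let old = "fn main() {\n    old();\n}\n";
    // Bytes 12..22 cover the line `    old();`.
    let ok = extract_editable(old, "fn main() {\n    new_fn();\n}\n", 12..22);
    assert_eq!(ok.unwrap(), "    new_fn();");
    // A result that also rewrites the closing brace is rejected.
    let bad = extract_editable(old, "fn main() {\n    new_fn();\n}}\n", 12..22);
    assert!(bad.is_err());
}
```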


@@ -56,7 +56,6 @@ watch.workspace = true
edit_prediction = { workspace = true, features = ["cli-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
zlog.workspace = true
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.


@@ -1,14 +1,22 @@
use anyhow::{Result, anyhow};
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) {
let [prediction]: [_; 1] = mem::take(&mut example.predictions)
.try_into()
.expect("Run predict first with a single repetition");
pub async fn run_distill(example: &mut Example) -> Result<()> {
let [prediction]: [_; 1] =
mem::take(&mut example.predictions)
.try_into()
.map_err(|preds: Vec<_>| {
anyhow!(
"Example has {} predictions, but it should have exactly one",
preds.len()
)
})?;
example.expected_patch = prediction.actual_patch;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
Ok(())
}
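`run_distill` now converts the `Vec<_> -> [_; 1]` conversion failure into an error instead of panicking with `expect`. The same shape, with a plain `Result<_, String>` standing in for `anyhow::Result`:

```rust
// Mirrors the array destructuring in run_distill: succeed only when the
// Vec holds exactly one element; on failure, report the actual length
// (try_into hands the original Vec back in the error case).
fn single<T>(items: Vec<T>) -> Result<T, String> {
    let [item]: [T; 1] = items
        .try_into()
        .map_err(|items: Vec<T>| format!("expected exactly 1 item, got {}", items.len()))?;
    Ok(item)
}

fn main() {
    assert_eq!(single(vec![42]), Ok(42));
    assert!(single(vec![1, 2]).is_err());
    assert!(single(Vec::<i32>::new()).is_err());
}
```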


@@ -6,6 +6,7 @@ use crate::{
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result, ensure};
use edit_prediction::{
EditPredictionStore,
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
@@ -18,12 +19,11 @@ pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
) -> Result<()> {
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
let _step_progress = progress.start(Step::FormatPrompt, &example.name);
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
match prompt_format {
PromptFormat::Teacher => {
@@ -35,31 +35,35 @@ pub async fn run_format_prompt(
});
}
PromptFormat::Zeta2 => {
run_load_project(example, app_state, progress.clone(), cx.clone()).await;
run_load_project(example, app_state, cx.clone()).await?;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let state = example.state.as_ref().unwrap();
let snapshot = state
.buffer
.read_with(&cx, |buffer, _| buffer.snapshot())
.unwrap();
let state = example.state.as_ref().context("state must be set")?;
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
let (_, input) = ep_store
.update(&mut cx, |ep_store, _cx| {
zeta2_prompt_input(
&snapshot,
example.context.as_ref().unwrap().files.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example.buffer.as_ref().unwrap().cursor_offset,
)
})
.unwrap();
let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
.context
.as_ref()
.context("context must be set")?
.files
.clone(),
ep_store.edit_history_for_project(&project, cx),
example.cursor_path.clone(),
example
.buffer
.as_ref()
.context("buffer must be set")?
.cursor_offset,
))
})??;
let prompt = format_zeta_prompt(&input);
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone())?;
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
@@ -67,6 +71,7 @@ pub async fn run_format_prompt(
});
}
};
Ok(())
}
pub struct TeacherPrompt;
@@ -92,7 +97,7 @@ impl TeacherPrompt {
prompt
}
pub fn parse(example: &Example, response: &str) -> String {
pub fn parse(example: &Example, response: &str) -> Result<String> {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
@@ -103,7 +108,7 @@ impl TeacherPrompt {
let cursor_file = &example
.buffer
.as_ref()
.expect("`buffer` should be filled in in the context collection step")
.context("`buffer` should be filled in in the context collection step")?
.content;
// Extract updated (new) editable region from the model response
@@ -112,9 +117,10 @@ impl TeacherPrompt {
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
if !cursor_file.contains(&old_editable_region) {
panic!("Something's wrong: editable_region is not found in the cursor file")
}
ensure!(
cursor_file.contains(&old_editable_region),
"Something's wrong: editable_region is not found in the cursor file"
);
// Apply editable region to a larger context and compute diff.
// This is needed to get a better context lines around the editable region
@@ -129,7 +135,7 @@ impl TeacherPrompt {
diff = diff,
};
diff
Ok(diff)
}
fn format_edit_history(edit_history: &str) -> String {
@@ -153,9 +159,7 @@ impl TeacherPrompt {
}
fn format_context(example: &Example) -> String {
if example.context.is_none() {
panic!("Missing context retriever step");
}
assert!(example.context.is_some(), "Missing context retriever step");
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);

View File

@@ -4,7 +4,7 @@ use crate::{
paths::{REPOS_DIR, WORKTREES_DIR},
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
@@ -28,40 +28,35 @@ use zeta_prompt::CURSOR_MARKER;
pub async fn run_load_project(
example: &mut Example,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
) -> Result<()> {
if example.state.is_some() {
return;
return Ok(());
}
let progress = progress.start(Step::LoadProject, &example.name);
let progress = Progress::global().start(Step::LoadProject, &example.name);
let project = setup_project(example, &app_state, &progress, &mut cx).await;
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
let _open_buffers = apply_edit_history(example, &project, &mut cx)
.await
.unwrap();
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
let (example_buffer, language_name) = buffer
.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})
.unwrap();
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})?;
progress.set_info(language_name, InfoStyle::Normal);
@@ -72,16 +67,15 @@ pub async fn run_load_project(
cursor_position,
_open_buffers,
});
Ok(())
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> (Entity<Buffer>, Anchor) {
let language_registry = project
.read_with(cx, |project, _| project.languages().clone())
.unwrap();
) -> Result<(Entity<Buffer>, Anchor)> {
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(&example.cursor_path)
.await;
@@ -89,17 +83,18 @@ async fn cursor_position(
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
panic!("Failed to load language for file path: {}", error);
return Err(error);
}
let worktree = project
.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})
.unwrap();
let worktree = project.read_with(cx, |project, cx| {
project
.visible_worktrees(cx)
.next()
.context("No visible worktrees")
})??;
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
.unwrap()
.context("Failed to create RelPath")?
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
@@ -110,15 +105,12 @@ async fn cursor_position(
},
cx,
)
})
.unwrap()
.await
.unwrap();
})?
.await?;
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))
.unwrap();
.context("missing cursor marker")?;
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
@@ -128,90 +120,76 @@ async fn cursor_position(
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
panic!(
let (excerpt_offset, _) = matches.next().with_context(|| {
format!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
example.name
);
});
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
excerpt_offset
}).unwrap();
)
})?;
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor = cursor_buffer
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
.unwrap();
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
(cursor_buffer, cursor_anchor)
Ok((cursor_buffer, cursor_anchor))
}
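The cursor-resolution steps in `cursor_position` above can be sketched in isolation: strip the cursor marker from the excerpt, then require exactly one occurrence of the de-markered excerpt in the buffer text. The marker string and function name below are simplified assumptions, not the crate's exports:

```rust
const CURSOR_MARKER: &str = "<|cursor|>";

// Returns the absolute cursor offset in `buffer_text`, failing if the marker
// is missing, the excerpt is absent, or the excerpt match is ambiguous.
fn resolve_cursor(excerpt_with_marker: &str, buffer_text: &str) -> Result<usize, String> {
    let offset_in_excerpt = excerpt_with_marker
        .find(CURSOR_MARKER)
        .ok_or_else(|| "missing cursor marker".to_string())?;
    let mut excerpt = excerpt_with_marker.to_string();
    excerpt.replace_range(offset_in_excerpt..offset_in_excerpt + CURSOR_MARKER.len(), "");
    let mut matches = buffer_text.match_indices(excerpt.as_str());
    let (excerpt_offset, _) = matches
        .next()
        .ok_or_else(|| "excerpt not found in buffer".to_string())?;
    if matches.next().is_some() {
        return Err("more than one cursor position match".to_string());
    }
    Ok(excerpt_offset + offset_in_excerpt)
}
```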
async fn setup_project(
example: &mut Example,
app_state: &Arc<EpAppState>,
step_progress: &Arc<StepProgress>,
step_progress: &StepProgress,
cx: &mut AsyncApp,
) -> Entity<Project> {
) -> Result<Entity<Project>> {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
.update(|cx| EditPredictionStore::try_global(cx))?
.context("Store should be initialized at init")?;
let worktree_path = setup_worktree(example, step_progress).await;
let worktree_path = setup_worktree(example, step_progress).await?;
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
ep_store
.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})
.unwrap();
let buffer_store = project
.read_with(cx, |project, _| project.buffer_store().clone())
.unwrap();
let buffers = buffer_store
.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})
.unwrap();
ep_store.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})?;
for buffer in buffers {
buffer
.update(cx, |buffer, cx| buffer.reload(cx))
.unwrap()
.update(cx, |buffer, cx| buffer.reload(cx))?
.await
.ok();
}
return project;
return Ok(project);
}
let project = cx
.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})
.unwrap();
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
project
.update(cx, |project, cx| {
project.disable_worktree_scanner(cx);
project.create_worktree(&worktree_path, true, cx)
})
.unwrap()
.await
.unwrap();
})?
.await?;
app_state
.project_cache
.insert(example.repository_url.clone(), project.clone());
let buffer_store = project
.read_with(cx, |project, _| project.buffer_store().clone())
.unwrap();
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
@@ -220,15 +198,14 @@ async fn setup_project(
}
_ => {}
}
})
.unwrap()
})?
.detach();
project
Ok(project)
}
async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) -> PathBuf {
let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let worktree_path = WORKTREES_DIR
.join(repo_owner.as_ref())
@@ -237,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
if !repo_dir.is_dir() {
step_progress.set_substatus(format!("cloning {}", repo_name));
fs::create_dir_all(&repo_dir).unwrap();
run_git(&repo_dir, &["init"]).await.unwrap();
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
.await
.unwrap();
.await?;
}
// Resolve the example to a revision, fetching it if needed.
@@ -264,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
.await
.unwrap();
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
revision
};
// Create the worktree for this example if needed.
step_progress.set_substatus("preparing worktree");
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"])
.await
.unwrap();
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
.await
.unwrap();
run_git(&worktree_path, &["checkout", revision.as_str()])
.await
.unwrap();
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
.await
.unwrap();
.await?;
run_git(
&repo_dir,
&[
@@ -302,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
&example.name,
],
)
.await
.unwrap();
.await?;
}
drop(repo_lock);
@@ -314,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()
.unwrap();
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin
.write_all(example.uncommitted_diff.as_bytes())
.await
.unwrap();
stdin.close().await.unwrap();
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await.unwrap();
if !apply_result.status.success() {
panic!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
let apply_result = apply_process.output().await?;
anyhow::ensure!(
apply_result.status.success(),
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
step_progress.clear_substatus();
worktree_path
Ok(worktree_path)
}
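`setup_worktree` leans on a `run_git` helper that now surfaces failures as `Err` instead of panicking. A dependency-free sketch of that shape, generalized to any program (the crate's real `run_git` is async and git-specific; `run_cmd` here is an assumed illustrative helper):

```rust
use std::path::Path;
use std::process::Command;

// Run a command in `dir`, returning trimmed stdout on success and stderr as
// the error on failure, so callers can use `?` rather than `.unwrap()`.
fn run_cmd(program: &str, dir: &Path, args: &[&str]) -> Result<String, String> {
    let output = Command::new(program)
        .current_dir(dir)
        .args(args)
        .output()
        .map_err(|e| format!("failed to spawn {program}: {e}"))?;
    if !output.status.success() {
        return Err(String::from_utf8_lossy(&output.stderr).into_owned());
    }
    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
}
```

With this shape, `run_cmd("git", &repo_dir, &["fetch", "origin"])?` composes the same way as the `run_git(...).await?` calls in the diff.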
async fn apply_edit_history(

View File

@@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
@@ -32,7 +34,7 @@ use crate::score::run_scoring;
struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
#[clap(long, default_value_t = 10)]
#[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
@@ -42,6 +44,8 @@ struct EpArgs {
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
#[arg(long, short, global = true)]
failfast: bool,
}
#[derive(Subcommand, Debug)]
@@ -67,6 +71,58 @@ enum Command {
Clean,
}
impl Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Command::ParseExample => write!(f, "parse-example"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
Command::FormatPrompt(format_prompt_args) => write!(
f,
"format-prompt --prompt-format={}",
format_prompt_args
.prompt_format
.to_possible_value()
.unwrap()
.get_name()
),
Command::Predict(predict_args) => {
write!(
f,
"predict --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Score(predict_args) => {
write!(
f,
"score --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Distill => write!(f, "distill"),
Command::Eval(predict_args) => write!(
f,
"eval --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
),
Command::Clean => write!(f, "clean"),
}
}
}
#[derive(Debug, Args)]
struct FormatPromptArgs {
#[clap(long)]
@@ -112,8 +168,6 @@ impl EpArgs {
}
fn main() {
let _ = zlog::try_init(Some("error".into()));
zlog::init_output_stderr();
let args = EpArgs::parse();
if args.printenv {
@@ -147,92 +201,140 @@ fn main() {
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await
};
let result = async {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
}
let total_examples = examples.len();
let progress = Progress::new(total_examples);
let total_examples = examples.len();
Progress::global().set_total_examples(total_examples);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(
example,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
let result = async {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(example, app_state.clone(), cx.clone())
.await?;
}
Command::Context => {
run_context_retrieval(
example,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Distill => {
run_distill(example).await?;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state.clone(), cx.clone())
.await?;
}
Command::Clean => {
unreachable!()
}
}
anyhow::Ok(())
}
Command::Context => {
run_context_retrieval(
example,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::Distill => {
run_distill(example).await;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(
example,
&args,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::Clean => {
unreachable!()
.await;
if let Err(e) = result {
Progress::global().increment_failed();
let failed_example_path =
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
app_state
.fs
.write(
&failed_example_path,
&serde_json::to_vec_pretty(&example).unwrap(),
)
.await
.unwrap();
let err_path =
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
app_state
.fs
.write(&err_path, e.to_string().as_bytes())
.await
.unwrap();
let msg = format!(
indoc::indoc! {"
While processing {}:
{:?}
Written to: \x1b[36m{}\x1b[0m
Explore this example data with:
fx \x1b[36m{}\x1b[0m
Re-run this example with:
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
"},
example.name,
e,
err_path.display(),
failed_example_path.display(),
command,
failed_example_path.display(),
);
if args.failfast || total_examples == 1 {
Progress::global().finalize();
panic!("{}", msg);
} else {
log::error!("{}", msg);
}
}
}
}
});
futures::future::join_all(futures).await;
}
progress.clear();
});
futures::future::join_all(futures).await;
}
Progress::global().finalize();
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
anyhow::Ok(())
}
.await;
if let Err(e) = result {
panic!("Fatal error: {:?}", e);
}
let _ = cx.update(|cx| cx.quit());
})
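The failure-handling branch above persists each failed example next to its error text so the run can be inspected and replayed. A simplified std-only sketch of that persistence step (plain strings instead of `serde_json`, and a hypothetical function name):

```rust
use std::{fs, path::Path};

// Write the example payload and its error side by side under a failures
// directory, mirroring the `{name}.json` / `{name}_err.txt` layout above.
fn persist_failure(dir: &Path, name: &str, example_json: &str, error: &str) -> std::io::Result<()> {
    fs::create_dir_all(dir)?;
    fs::write(dir.join(format!("{name}.json")), example_json)?;
    fs::write(dir.join(format!("{name}_err.txt")), error)?;
    Ok(())
}
```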

View File

@@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");

View File

@@ -9,6 +9,7 @@ use crate::{
progress::{InfoStyle, Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
@@ -25,41 +26,33 @@ pub async fn run_prediction(
provider: Option<PredictionProvider>,
repetition_count: usize,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
) -> anyhow::Result<()> {
if !example.predictions.is_empty() {
return;
return Ok(());
}
let provider = provider.unwrap();
let provider = provider.context("provider is required")?;
run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
if matches!(
provider,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
) {
let _step_progress = progress.start(Step::Predict, &example.name);
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if example.prompt.is_none() {
run_format_prompt(
example,
PromptFormat::Teacher,
app_state.clone(),
progress,
cx,
)
.await;
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
}
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
run_load_project(example, app_state.clone(), cx.clone()).await?;
let _step_progress = progress.start(Step::Predict, &example.name);
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if matches!(
provider,
@@ -70,10 +63,9 @@ pub async fn run_prediction(
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
eprintln!("Authentication failed: {}", e);
}
})
.shared()
})
@@ -81,33 +73,30 @@ pub async fn run_prediction(
.await;
}
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
ep_store
.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})
.unwrap();
let state = example.state.as_ref().unwrap();
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})?;
let state = example.state.as_ref().context("state must be set")?;
let run_dir = RUN_DIR.join(&example.name);
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
let mut debug_rx = ep_store
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
.unwrap();
let mut debug_rx =
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
@@ -161,14 +150,14 @@ pub async fn run_prediction(
run_dir.clone()
};
fs::create_dir_all(&run_dir).unwrap();
fs::create_dir_all(&run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
updated_example
.lock()
@@ -189,10 +178,8 @@ pub async fn run_prediction(
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})
.unwrap()
.await
.unwrap();
})?
.await?;
let actual_patch = prediction
.and_then(|prediction| {
@@ -221,20 +208,23 @@ pub async fn run_prediction(
}
}
ep_store
.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})
.unwrap();
debug_task.await.unwrap();
ep_store.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})?;
debug_task.await?;
*example = Arc::into_inner(updated_example)
.unwrap()
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
.into_inner()
.unwrap();
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
Ok(())
}
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
async fn predict_anthropic(
example: &mut Example,
_repetition_count: usize,
batched: bool,
) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
@@ -242,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
} else {
AnthropicClient::plain()
};
let llm_client = llm_client.expect("Failed to create LLM client");
let llm_client = llm_client.context("Failed to create LLM client")?;
let prompt = example
.prompt
.as_ref()
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
let prompt = example.prompt.as_ref().context("Prompt is required")?;
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
@@ -259,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
.await
.unwrap()
.await?
else {
// Request stashed for batched processing
return;
return Ok(());
};
let actual_output = response
@@ -276,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
.collect::<Vec<String>>()
.join("\n");
let actual_patch = TeacherPrompt::parse(example, &actual_output);
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch,
@@ -285,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
};
example.predictions.push(prediction);
Ok(())
}
pub async fn sync_batches(provider: &PredictionProvider) {
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
llm_client
.sync_batches()
.await
.expect("Failed to sync batches");
.context("Failed to sync batches")?;
}
_ => (),
}
};
Ok(())
}

View File

@@ -2,10 +2,12 @@ use std::{
borrow::Cow,
collections::HashMap,
io::{IsTerminal, Write},
sync::{Arc, Mutex},
sync::{Arc, Mutex, OnceLock},
time::{Duration, Instant},
};
use log::{Level, Log, Metadata, Record};
pub struct Progress {
inner: Mutex<ProgressInner>,
}
@@ -18,6 +20,8 @@ struct ProgressInner {
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
failed_examples: usize,
last_line_is_logging: bool,
}
#[derive(Clone)]
@@ -72,70 +76,120 @@ impl Step {
}
}
const RIGHT_MARGIN: usize = 4;
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
static LOGGER: ProgressLogger = ProgressLogger;
const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
impl Progress {
pub fn new(total_examples: usize) -> Arc<Self> {
Arc::new(Self {
inner: Mutex::new(ProgressInner {
completed: Vec::new(),
in_progress: HashMap::new(),
is_tty: std::io::stderr().is_terminal(),
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples,
}),
})
/// Returns the global Progress instance, initializing it if necessary.
pub fn global() -> Arc<Progress> {
GLOBAL
.get_or_init(|| {
let progress = Arc::new(Self {
inner: Mutex::new(ProgressInner {
completed: Vec::new(),
in_progress: HashMap::new(),
is_tty: std::io::stderr().is_terminal(),
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples: 0,
failed_examples: 0,
last_line_is_logging: false,
}),
});
let _ = log::set_logger(&LOGGER);
log::set_max_level(log::LevelFilter::Error);
progress
})
.clone()
}
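The `Progress::global()` change above replaces a passed-around `Arc<Progress>` with a lazily initialized process-wide singleton. A minimal sketch of that `OnceLock` pattern, with the inner state collapsed to a single counter as a placeholder for `ProgressInner`:

```rust
use std::sync::{Arc, Mutex, OnceLock};

struct Progress {
    inner: Mutex<usize>, // stands in for the real ProgressInner
}

static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();

impl Progress {
    // First caller initializes the singleton; every caller gets a clone of
    // the same Arc, so state is shared process-wide without plumbing.
    fn global() -> Arc<Progress> {
        GLOBAL
            .get_or_init(|| Arc::new(Progress { inner: Mutex::new(0) }))
            .clone()
    }

    fn increment_failed(&self) {
        *self.inner.lock().unwrap() += 1;
    }

    fn failed(&self) -> usize {
        *self.inner.lock().unwrap()
    }
}
```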
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> Arc<StepProgress> {
{
let mut inner = self.inner.lock().unwrap();
pub fn set_total_examples(&self, total: usize) {
let mut inner = self.inner.lock().unwrap();
inner.total_examples = total;
}
Self::clear_status_lines(&mut inner);
pub fn increment_failed(&self) {
let mut inner = self.inner.lock().unwrap();
inner.failed_examples += 1;
}
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
/// This should be used for any output that needs to appear above the status lines.
fn log(&self, message: &str) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
step,
started_at: Instant::now(),
substatus: None,
info: None,
},
);
Self::print_status_lines(&mut inner);
if !inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = true;
}
eprintln!("{}", message);
}
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
step,
started_at: Instant::now(),
substatus: None,
info: None,
},
);
Self::print_status_lines(&mut inner);
StepProgress {
progress: self.clone(),
step,
example_name: example_name.to_string(),
}
}
}
pub fn finish(&self, step: Step, example_name: &str) {
let mut inner = self.inner.lock().unwrap();
let Some(task) = inner.in_progress.remove(example_name) else {
return;
};
if task.step == step {
inner.completed.push(CompletedTask {
step: task.step,
example_name: example_name.to_string(),
duration: task.started_at.elapsed(),
info: task.info,
});
Self::clear_status_lines(&mut inner);
Self::print_logging_closing_divider(&mut inner);
Self::print_completed(&inner, inner.completed.last().unwrap());
Self::print_status_lines(&mut inner);
} else {
inner.in_progress.insert(example_name.to_string(), task);
}
}
fn print_logging_closing_divider(inner: &mut ProgressInner) {
if inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = false;
}
}
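`clear_status_lines` itself is outside this hunk, but the redraw cycle above depends on it undoing the previously printed status block before anything new is written. A minimal standalone sketch of that mechanism, assuming the conventional CSI escapes (cursor-up then erase-line, once per displayed line) rather than the crate's actual implementation:

```rust
// Build the escape sequence that erases the last `count` lines printed to a
// terminal: "\x1b[1A" moves the cursor up one line, "\x1b[2K" erases it.
fn clear_lines_escape(count: usize) -> String {
    "\x1b[1A\x1b[2K".repeat(count)
}

fn main() {
    let seq = clear_lines_escape(2);
    // Two lines cleared: two cursor-up/erase pairs.
    assert_eq!(seq, "\x1b[1A\x1b[2K\x1b[1A\x1b[2K");
}
```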
@@ -182,7 +236,7 @@ impl Progress {
let duration_with_margin = format!("{duration} ");
let padding_needed = inner
.terminal_width
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -216,27 +270,41 @@ impl Progress {
// Build the done/in-progress/total label
let done_count = inner.completed.len();
let in_progress_count = inner.in_progress.len();
let failed_count = inner.failed_examples;
let failed_label = if failed_count > 0 {
format!(" {} failed ", failed_count)
} else {
String::new()
};
let range_label = format!(
" {}/{}/{} ",
done_count, in_progress_count, inner.total_examples
);
// Print a divider line with failed count on left, range label on right
let failed_visible_len = strip_ansi_len(&failed_label);
let range_visible_len = range_label.len();
let middle_divider_len = inner
.terminal_width
.saturating_sub(MARGIN * 2)
.saturating_sub(failed_visible_len)
.saturating_sub(range_visible_len);
let left_divider = "─".repeat(MARGIN);
let middle_divider = "─".repeat(middle_divider_len);
let right_divider = "─".repeat(MARGIN);
eprintln!(
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
);
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
tasks.sort_by_key(|(name, _)| *name);
let total_tasks = tasks.len();
let mut lines_printed = 0;
for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
let elapsed = format_duration(task.started_at.elapsed());
let substatus_part = task
.substatus
@@ -256,7 +324,7 @@ impl Progress {
let duration_with_margin = format!("{elapsed} ");
let padding_needed = inner
.terminal_width
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -265,13 +333,34 @@ impl Progress {
lines_printed += 1;
}
// Show "+N more" on its own line if there are more tasks
if total_tasks > MAX_STATUS_LINES {
let remaining = total_tasks - MAX_STATUS_LINES;
eprintln!("{:>12} +{remaining} more", "");
lines_printed += 1;
}
inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
let _ = std::io::stderr().flush();
}
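The padding arithmetic above subtracts `strip_ansi_len(&prefix)` rather than `prefix.len()`, because ANSI color codes occupy bytes but no terminal columns. A hypothetical stand-in for that helper (the crate's real implementation is not shown in this diff and may differ): count visible characters while skipping CSI sequences, i.e. ESC `[` followed by parameter bytes up to a final byte in `0x40..=0x7e`.

```rust
// Visible width of a string, ignoring CSI escape sequences like "\x1b[31m".
fn strip_ansi_len(s: &str) -> usize {
    let mut len = 0;
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\x1b' && chars.peek() == Some(&'[') {
            chars.next(); // consume '['
            // Skip until the final byte of the CSI sequence.
            while let Some(&next) = chars.peek() {
                chars.next();
                if ('\x40'..='\x7e').contains(&next) {
                    break;
                }
            }
        } else {
            len += 1;
        }
    }
    len
}

fn main() {
    // "Error" wrapped in red + reset codes still measures 5 columns.
    assert_eq!(strip_ansi_len("\x1b[31mError\x1b[0m"), 5);
    assert_eq!(strip_ansi_len("plain"), 5);
}
```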
pub fn finalize(&self) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
// Print summary if there were failures
if inner.failed_examples > 0 {
let total_processed = inner.completed.len() + inner.failed_examples;
let percentage = if total_processed > 0 {
inner.failed_examples as f64 / total_processed as f64 * 100.0
} else {
0.0
};
eprintln!(
"\n{} of {} examples failed ({:.1}%)",
inner.failed_examples, total_processed, percentage
);
}
}
}
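The summary arithmetic in `finalize` can be checked in isolation; a standalone sketch of the same computation (failed examples as a percentage of all processed examples, guarding against division by zero):

```rust
// Failure rate over completed + failed examples, as a percentage.
fn failure_percentage(failed: usize, completed: usize) -> f64 {
    let total = completed + failed;
    if total > 0 {
        failed as f64 / total as f64 * 100.0
    } else {
        0.0
    }
}

fn main() {
    assert_eq!(failure_percentage(1, 3), 25.0);
    // No examples processed: report 0% instead of dividing by zero.
    assert_eq!(failure_percentage(0, 0), 0.0);
}
```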
@@ -314,6 +403,53 @@ impl Drop for StepProgress {
}
}
struct ProgressLogger;
impl Log for ProgressLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
metadata.level() <= Level::Info
}
fn log(&self, record: &Record) {
if !self.enabled(record.metadata()) {
return;
}
let level_color = match record.level() {
Level::Error => "\x1b[31m",
Level::Warn => "\x1b[33m",
Level::Info => "\x1b[32m",
Level::Debug => "\x1b[34m",
Level::Trace => "\x1b[35m",
};
let reset = "\x1b[0m";
let bold = "\x1b[1m";
let level_label = match record.level() {
Level::Error => "Error",
Level::Warn => "Warn",
Level::Info => "Info",
Level::Debug => "Debug",
Level::Trace => "Trace",
};
let message = format!(
"{bold}{level_color}{level_label:>12}{reset} {}",
record.args()
);
if let Some(progress) = GLOBAL.get() {
progress.log(&message);
} else {
eprintln!("{}", message);
}
}
fn flush(&self) {
let _ = std::io::stderr().flush();
}
}
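`ProgressLogger` right-aligns a bold, colored level label in a 12-column gutter so log lines line up with the status rows. A stdlib-only sketch of just that formatting step (the real code implements the `log` crate's `Log` trait; this helper name is hypothetical):

```rust
// Format a log line with a colored level label right-aligned to 12 columns.
fn format_log_line(level: &str, color: &str, message: &str) -> String {
    let reset = "\x1b[0m";
    let bold = "\x1b[1m";
    format!("{bold}{color}{level:>12}{reset} {message}")
}

fn main() {
    // "Warn" (4 chars) is padded with 8 leading spaces to fill the gutter.
    assert_eq!(
        format_log_line("Warn", "\x1b[33m", "hi"),
        "\x1b[1m\x1b[33m        Warn\x1b[0m hi"
    );
}
```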
#[cfg(unix)]
fn get_terminal_width() -> usize {
unsafe {


@@ -4,6 +4,7 @@ use crate::{
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::Context as _;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
@@ -16,39 +17,36 @@ use std::time::Duration;
pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
if example.context.is_some() {
return Ok(());
}
run_load_project(example, app_state.clone(), cx.clone()).await?;
let step_progress: Arc<StepProgress> = Progress::global()
.start(Step::Context, &example.name)
.into();
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
let _lsp_handle = project.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})?;
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let mut events = ep_store.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})?;
while let Some(event) = events.next().await {
match event {
@@ -59,9 +57,8 @@ pub async fn run_context_retrieval(
}
}
let context_files =
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
@@ -69,6 +66,7 @@ pub async fn run_context_retrieval(
example.context = Some(ExampleContext {
files: context_files,
});
Ok(())
}
async fn wait_for_language_servers_to_start(
@@ -76,10 +74,8 @@ async fn wait_for_language_servers_to_start(
buffer: &Entity<Buffer>,
step_progress: &Arc<StepProgress>,
cx: &mut AsyncApp,
) -> anyhow::Result<()> {
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
let (language_server_ids, mut starting_language_server_ids) = buffer
.update(cx, |buffer, cx| {
@@ -122,7 +118,7 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
@@ -131,8 +127,7 @@ async fn wait_for_language_servers_to_start(
if !language_server_ids.is_empty() {
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.detach();
}
@@ -174,10 +169,8 @@ async fn wait_for_language_servers_to_start(
];
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
while !pending_language_server_ids.is_empty() {
@@ -188,11 +181,12 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
drop(subscriptions);
step_progress.clear_substatus();
Ok(())
}
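The running theme of this file's changes is replacing `.unwrap()` with `?` so a missing global or a failed entity update surfaces as an `anyhow::Result` error instead of a panic. A stdlib-only sketch of the same pattern (`try_global` here is a hypothetical stand-in, not the real API; `ok_or` plays the role of `anyhow::Context` by attaching a message):

```rust
// Pretend lookup for a global that was never initialized.
fn try_global() -> Option<&'static str> {
    None
}

// `?` propagates the error instead of panicking like `.unwrap()` would.
fn run() -> Result<&'static str, String> {
    let store = try_global().ok_or("EditPredictionStore not initialized")?;
    Ok(store)
}

fn main() {
    assert_eq!(run(), Err("EditPredictionStore not initialized".to_string()));
}
```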


@@ -14,20 +14,18 @@ pub async fn run_scoring(
example: &mut Example,
args: &PredictArgs,
app_state: Arc<EpAppState>,
cx: AsyncApp,
) -> anyhow::Result<()> {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state,
cx,
)
.await?;
let _progress = Progress::global().start(Step::Score, &example.name);
let expected_patch = parse_patch(&example.expected_patch);
@@ -45,6 +43,7 @@ pub async fn run_scoring(
}
example.score = scores;
Ok(())
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {


@@ -20,8 +20,8 @@ cloud_llm_client.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
edit_prediction.workspace = true
edit_prediction_types.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@@ -41,7 +41,6 @@ telemetry.workspace = true
text.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true


@@ -3,7 +3,9 @@ use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
use codestral::CodestralEditPredictionDelegate;
use copilot::{Copilot, Status};
use edit_prediction::{
EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag,
};
use edit_prediction_types::EditPredictionDelegateHandle;
use editor::{
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
@@ -42,12 +44,9 @@ use workspace::{
StatusItemView, Toast, Workspace, create_and_open_local_file, item::ItemHandle,
notifications::NotificationId,
};
use zed_actions::{OpenBrowser, OpenSettingsAt};
use crate::{RatePredictions, rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag};
actions!(
edit_prediction,
@@ -248,45 +247,21 @@ impl Render for EditPredictionButton {
EditPredictionProvider::Codestral => {
let enabled = self.editor_enabled.unwrap_or(true);
let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx);
let this = cx.weak_entity();
let tooltip_meta = if has_api_key {
"Powered by Codestral"
} else {
"Missing API key for Codestral"
};
div().child(
PopoverMenu::new("codestral")
.menu(move |window, cx| {
this.update(cx, |this, cx| {
this.build_codestral_context_menu(window, cx)
})
.ok()
})
.anchor(Corner::BottomRight)
.trigger_with_tooltip(
@@ -304,7 +279,14 @@ impl Render for EditPredictionButton {
cx.theme().colors().status_bar_background,
))
}),
move |_window, cx| {
Tooltip::with_meta(
"Edit Prediction",
Some(&ToggleMenu),
tooltip_meta,
cx,
)
},
)
.with_handle(self.popover_menu_handle.clone()),
)
@@ -313,6 +295,7 @@ impl Render for EditPredictionButton {
let enabled = self.editor_enabled.unwrap_or(true);
let ep_icon;
let tooltip_meta;
let mut missing_token = false;
match provider {
@@ -320,15 +303,25 @@ impl Render for EditPredictionButton {
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
ep_icon = IconName::SweepAi;
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token(cx));
tooltip_meta = if missing_token {
"Missing API key for Sweep"
} else {
"Powered by Sweep"
};
}
EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
) => {
ep_icon = IconName::Inception;
missing_token = edit_prediction::EditPredictionStore::try_global(cx)
.is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token(cx));
tooltip_meta = if missing_token {
"Missing API key for Mercury"
} else {
"Powered by Mercury"
};
}
_ => {
ep_icon = if enabled {
@@ -336,6 +329,7 @@ impl Render for EditPredictionButton {
} else {
IconName::ZedPredictDisabled
};
tooltip_meta = "Powered by Zeta"
}
};
@@ -400,33 +394,26 @@ impl Render for EditPredictionButton {
})
.when(!self.popover_menu_handle.is_deployed(), |element| {
let user = user.clone();
element.tooltip(move |_window, cx| {
let description = if enabled {
if show_editor_predictions {
tooltip_meta
} else if user.is_none() {
"Sign In To Use"
} else {
"Hidden For This File"
}
} else {
"Disabled For This File"
};
Tooltip::with_meta(
"Edit Prediction",
Some(&ToggleMenu),
description,
cx,
)
})
});
@@ -519,6 +506,12 @@ impl EditPredictionButton {
providers.push(EditPredictionProvider::Zed);
if cx.has_flag::<Zeta2FeatureFlag>() {
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
));
}
if let Some(copilot) = Copilot::global(cx) {
if matches!(copilot.read(cx).status(), Status::Authorized) {
providers.push(EditPredictionProvider::Copilot);
@@ -537,24 +530,28 @@ impl EditPredictionButton {
providers.push(EditPredictionProvider::Codestral);
}
let ep_store = EditPredictionStore::try_global(cx);
if cx.has_flag::<SweepFeatureFlag>()
&& ep_store
.as_ref()
.is_some_and(|ep_store| ep_store.read(cx).has_sweep_api_token(cx))
{
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
));
}
if cx.has_flag::<MercuryFeatureFlag>()
&& ep_store
.as_ref()
.is_some_and(|ep_store| ep_store.read(cx).has_mercury_api_token(cx))
{
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
));
}
providers
}
@@ -562,13 +559,10 @@ impl EditPredictionButton {
&self,
mut menu: ContextMenu,
current_provider: EditPredictionProvider,
cx: &mut App,
) -> ContextMenu {
let available_providers = self.get_available_providers(cx);
let providers: Vec<_> = available_providers
.into_iter()
.filter(|p| *p != EditPredictionProvider::None)
@@ -581,153 +575,32 @@ impl EditPredictionButton {
let is_current = provider == current_provider;
let fs = self.fs.clone();
let name = match provider {
EditPredictionProvider::Zed => "Zed AI",
EditPredictionProvider::Copilot => "GitHub Copilot",
EditPredictionProvider::Supermaven => "Supermaven",
EditPredictionProvider::Codestral => "Codestral",
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => "Sweep",
EditPredictionProvider::Experimental(
EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
) => "Mercury",
EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
) => "Zeta2",
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
continue;
}
};
menu = menu.item(
ContextMenuEntry::new(name)
.toggleable(IconPosition::Start, is_current)
.handler(move |_, cx| {
set_completion_provider(fs.clone(), cx, provider);
}),
)
}
}
@@ -832,14 +705,7 @@ impl EditPredictionButton {
let subtle_mode = matches!(current_mode, EditPredictionsMode::Subtle);
let eager_mode = matches!(current_mode, EditPredictionsMode::Eager);
menu = menu
.separator()
.header("Display Modes")
.item(
@@ -868,104 +734,111 @@ impl EditPredictionButton {
}
}),
);
menu = menu.separator().header("Privacy");
if matches!(
provider,
EditPredictionProvider::Zed
| EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
)
) {
if let Some(provider) = &self.edit_prediction_provider {
let data_collection = provider.data_collection_state(cx);
if data_collection.is_supported() {
let provider = provider.clone();
let enabled = data_collection.is_enabled();
let is_open_source = data_collection.is_project_open_source();
let is_collecting = data_collection.is_enabled();
let (icon_name, icon_color) = if is_open_source && is_collecting {
(IconName::Check, Color::Success)
} else {
(IconName::Check, Color::Accent)
};
menu = menu.item(
ContextMenuEntry::new("Training Data Collection")
.toggleable(IconPosition::Start, data_collection.is_enabled())
.icon(icon_name)
.icon_color(icon_color)
.documentation_aside(DocumentationSide::Left, DocumentationEdge::Top, move |cx| {
let (msg, label_color, icon_name, icon_color) = match (is_open_source, is_collecting) {
(true, true) => (
"Project identified as open source, and you're sharing data.",
Color::Default,
IconName::Check,
Color::Success,
),
(true, false) => (
"Project identified as open source, but you're not sharing data.",
Color::Muted,
IconName::Close,
Color::Muted,
),
(false, true) => (
"Project not identified as open source. No data captured.",
Color::Muted,
IconName::Close,
Color::Muted,
),
(false, false) => (
"Project not identified as open source, and setting turned off.",
Color::Muted,
IconName::Close,
Color::Muted,
),
};
v_flex()
.gap_2()
.child(
Label::new(indoc!{
"Help us improve our open dataset model by sharing data from open source repositories. \
Zed must detect a license file in your repo for this setting to take effect. \
Files with sensitive data and secrets are excluded by default."
})
)
.child(
h_flex()
.items_start()
.pt_2()
.pr_1()
.flex_1()
.gap_1p5()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.child(h_flex().flex_shrink_0().h(line_height).child(Icon::new(icon_name).size(IconSize::XSmall).color(icon_color)))
.child(div().child(msg).w_full().text_sm().text_color(label_color.color(cx)))
)
.into_any_element()
})
.handler(move |_, cx| {
provider.toggle_data_collection(cx);
if !enabled {
telemetry::event!(
"Data Collection Enabled",
source = "Edit Prediction Status Menu"
);
} else {
telemetry::event!(
"Data Collection Disabled",
source = "Edit Prediction Status Menu"
);
}
})
);
if is_collecting && !is_open_source {
menu = menu.item(
ContextMenuEntry::new("No data captured.")
.disabled(true)
.icon(IconName::Close)
.icon_color(Color::Error)
.icon_size(IconSize::Small),
);
}
}
}
}
@@ -1087,10 +960,7 @@ impl EditPredictionButton {
let menu =
self.add_provider_switching_section(menu, EditPredictionProvider::Codestral, cx);
menu
})
}
@@ -1210,6 +1080,22 @@ impl EditPredictionButton {
}
menu = self.add_provider_switching_section(menu, provider, cx);
menu = menu.separator().item(
ContextMenuEntry::new("Configure Providers")
.icon(IconName::Settings)
.icon_position(IconPosition::Start)
.icon_color(Color::Muted)
.handler(move |window, cx| {
window.dispatch_action(
OpenSettingsAt {
path: "edit_predictions.providers".to_string(),
}
.boxed_clone(),
cx,
);
}),
);
menu
})
}


@@ -1,6 +1,5 @@
mod edit_prediction_button;
mod edit_prediction_context_view;
mod rate_prediction_modal;
use std::any::{Any as _, TypeId};
@@ -17,7 +16,6 @@ use ui::{App, prelude::*};
use workspace::{SplitDirection, Workspace};
pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;


@@ -1,86 +0,0 @@
use edit_prediction::EditPredictionStore;
use gpui::{
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
};
use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
use ui_input::InputField;
use workspace::ModalView;
pub struct ExternalProviderApiKeyModal {
api_key_input: Entity<InputField>,
focus_handle: FocusHandle,
on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
}
impl ExternalProviderApiKeyModal {
pub fn new(
window: &mut Window,
cx: &mut Context<Self>,
on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
) -> Self {
let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
Self {
api_key_input,
focus_handle: cx.focus_handle(),
on_confirm: Box::new(on_confirm),
}
}
fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_input.read(cx).text(cx);
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
if let Some(ep_store) = EditPredictionStore::try_global(cx) {
ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
}
cx.emit(DismissEvent);
}
}
impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
impl ModalView for ExternalProviderApiKeyModal {}
impl Focusable for ExternalProviderApiKeyModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl Render for ExternalProviderApiKeyModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.key_context("ExternalApiKeyModal")
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::confirm))
.elevation_2(cx)
.w(px(400.))
.p_4()
.gap_3()
.child(Headline::new("API Token").size(HeadlineSize::Small))
.child(self.api_key_input.clone())
.child(
h_flex()
.justify_end()
.gap_2()
.child(Button::new("cancel", "Cancel").on_click(cx.listener(
|_, _, _window, cx| {
cx.emit(DismissEvent);
},
)))
.child(
Button::new("save", "Save")
.style(ButtonStyle::Filled)
.on_click(cx.listener(|this, _, window, cx| {
this.confirm(&menu::Confirm, window, cx);
})),
),
)
}
}


@@ -12,10 +12,10 @@ impl FeatureFlag for PanicFeatureFlag {
const NAME: &'static str = "panic";
}
pub struct InlineAssistantUseToolFeatureFlag;
impl FeatureFlag for InlineAssistantUseToolFeatureFlag {
const NAME: &'static str = "inline-assistant-use-tool";
fn enabled_for_staff() -> bool {
false


@@ -636,7 +636,6 @@ impl PickerDelegate for BranchListDelegate {
return Task::ready(());
};
let display_remotes = self.display_remotes;
cx.spawn_in(window, async move |picker, cx| {
let mut matches: Vec<Entry> = if query.is_empty() {
@@ -649,7 +648,6 @@ impl PickerDelegate for BranchListDelegate {
!branch.is_remote()
}
})
.map(|branch| Entry::Branch {
branch,
positions: Vec::new(),

View File

@@ -21,7 +21,6 @@ default = ["font-kit", "wayland", "x11", "windows-manifest"]
test-support = [
"leak-detection",
"collections/test-support",
"rand",
"util/test-support",
"http_client/test-support",
"wayland",
@@ -109,7 +108,7 @@ parking = "2.0.0"
parking_lot.workspace = true
postage.workspace = true
profiling.workspace = true
rand = { optional = true, workspace = true }
rand.workspace = true
raw-window-handle = "0.6"
refineable.workspace = true
resvg = { version = "0.45.0", default-features = false, features = [
@@ -158,8 +157,10 @@ media.workspace = true
objc.workspace = true
objc2 = { version = "0.6", optional = true }
objc2-metal = { version = "0.3", optional = true }
mach2.workspace = true
#TODO: replace with "objc2"
metal.workspace = true
flume = "0.11"
[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))'.dependencies]
pathfinder_geometry = "0.5"

View File

@@ -84,6 +84,8 @@ mod macos {
.allowlist_var("_dispatch_main_q")
.allowlist_var("_dispatch_source_type_data_add")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_HIGH")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_LOW")
.allowlist_var("DISPATCH_TIME_NOW")
.allowlist_function("dispatch_get_global_queue")
.allowlist_function("dispatch_async_f")

View File

@@ -28,6 +28,8 @@ pub use entity_map::*;
use http_client::{HttpClient, Url};
use smallvec::SmallVec;
#[cfg(any(test, feature = "test-support"))]
pub use test_app::*;
#[cfg(any(test, feature = "test-support"))]
pub use test_context::*;
use util::{ResultExt, debug_panic};
@@ -38,10 +40,11 @@ use crate::{
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
Reservation, ScreenCaptureSource, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, Priority,
PromptBuilder, PromptButton, PromptHandle, PromptLevel, Render, RenderImage,
RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString, SubscriberSet,
Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, WindowHandle, WindowId,
WindowInvalidator,
colors::{Colors, GlobalColors},
current_platform, hash, init_app_menus,
};
@@ -50,6 +53,8 @@ mod async_context;
mod context;
mod entity_map;
#[cfg(any(test, feature = "test-support"))]
mod test_app;
#[cfg(any(test, feature = "test-support"))]
mod test_context;
/// The duration for which futures returned from [Context::on_app_quit] can run before the application fully quits.
@@ -1494,6 +1499,24 @@ impl App {
.spawn(async move { f(&mut cx).await })
}
/// Spawns the future returned by the given function on the main thread with
/// the given priority. The closure will be invoked with [AsyncApp], which
/// allows the application state to be accessed across await points.
pub fn spawn_with_priority<AsyncFn, R>(&self, priority: Priority, f: AsyncFn) -> Task<R>
where
AsyncFn: AsyncFnOnce(&mut AsyncApp) -> R + 'static,
R: 'static,
{
if self.quitting {
debug_panic!("Can't spawn on main thread after on_app_quit")
};
let mut cx = self.to_async();
self.foreground_executor
.spawn_with_priority(priority, async move { f(&mut cx).await })
}
/// Schedules the given function to be run at the end of the current effect cycle, allowing entities
/// that are currently on the stack to be returned to the app.
pub fn defer(&mut self, f: impl FnOnce(&mut App) + 'static) {

View File
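The `spawn_with_priority` addition above just threads a `Priority` value through to the foreground executor. The scheduling effect can be sketched with std alone; `ToyExecutor` and its method names below are illustrative, not GPUI internals:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

// Toy single-threaded scheduler: higher priority runs first; equal
// priorities run in submission order (enforced via a sequence number).
struct ToyExecutor {
    seq: u64,
    queue: BinaryHeap<(u8, Reverse<u64>, String)>,
}

impl ToyExecutor {
    fn new() -> Self {
        Self { seq: 0, queue: BinaryHeap::new() }
    }

    fn spawn_with_priority(&mut self, priority: u8, name: &str) {
        self.seq += 1;
        self.queue.push((priority, Reverse(self.seq), name.to_string()));
    }

    // Drain the queue, returning task names in the order they ran.
    fn run_until_parked(&mut self) -> Vec<String> {
        let mut ran = Vec::new();
        while let Some((_, _, name)) = self.queue.pop() {
            ran.push(name);
        }
        ran
    }
}

fn main() {
    let mut ex = ToyExecutor::new();
    ex.spawn_with_priority(0, "low");
    ex.spawn_with_priority(2, "high");
    ex.spawn_with_priority(1, "default");
    assert_eq!(ex.run_until_parked(), ["high", "default", "low"]);
}
```

The tuple ordering does the work: `BinaryHeap` is a max-heap, so the highest priority pops first, and `Reverse(seq)` breaks ties in favor of earlier submissions.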

@@ -1,7 +1,7 @@
use crate::{
AnyView, AnyWindowHandle, AppContext, AsyncApp, DispatchPhase, Effect, EntityId, EventEmitter,
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Reservation, SubscriberSet,
Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Priority, Reservation,
SubscriberSet, Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
};
use anyhow::Result;
use futures::FutureExt;
@@ -667,6 +667,25 @@ impl<'a, T: 'static> Context<'a, T> {
window.spawn(self, async move |cx| f(view, cx).await)
}
/// Schedule a future to be run asynchronously with the given priority.
/// The given callback is invoked with a [`WeakEntity<V>`] to avoid leaking the entity for a long-running process.
/// It's also given an [`AsyncWindowContext`], which can be used to access the state of the entity across await points.
/// The returned future will be polled on the main thread.
#[track_caller]
pub fn spawn_in_with_priority<AsyncFn, R>(
&self,
priority: Priority,
window: &Window,
f: AsyncFn,
) -> Task<R>
where
R: 'static,
AsyncFn: AsyncFnOnce(WeakEntity<T>, &mut AsyncWindowContext) -> R + 'static,
{
let view = self.weak_entity();
window.spawn_with_priority(priority, self, async move |cx| f(view, cx).await)
}
/// Register a callback to be invoked when the given global state changes.
pub fn observe_global_in<G: Global>(
&mut self,

View File

@@ -0,0 +1,596 @@
//! A clean testing API for GPUI applications.
//!
//! `TestApp` provides a simpler alternative to `TestAppContext` with:
//! - Automatic effect flushing after updates
//! - Clean window creation and inspection
//! - Input simulation helpers
//!
//! # Example
//! ```ignore
//! #[test]
//! fn test_my_view() {
//! let mut app = TestApp::new();
//!
//! let mut window = app.open_window(|window, cx| {
//! MyView::new(window, cx)
//! });
//!
//! window.update(|view, window, cx| {
//! view.do_something(cx);
//! });
//!
//! // Check rendered state
//! assert_eq!(window.title(), Some("Expected Title"));
//! }
//! ```
use crate::{
AnyWindowHandle, App, AppCell, AppContext, AsyncApp, BackgroundExecutor, BorrowAppContext,
Bounds, ClipboardItem, Context, Entity, ForegroundExecutor, Global, InputEvent, Keystroke,
MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels, Platform, Point, Render,
SceneSnapshot, Size, Task, TestDispatcher, TestPlatform, TextSystem, Window, WindowBounds,
WindowHandle, WindowOptions, app::GpuiMode,
};
use rand::{SeedableRng, rngs::StdRng};
use std::{future::Future, rc::Rc, sync::Arc, time::Duration};
/// A test application context with a clean API.
///
/// Unlike `TestAppContext`, `TestApp` automatically flushes effects after
/// each update and provides simpler window management.
pub struct TestApp {
app: Rc<AppCell>,
platform: Rc<TestPlatform>,
background_executor: BackgroundExecutor,
foreground_executor: ForegroundExecutor,
#[allow(dead_code)]
dispatcher: TestDispatcher,
text_system: Arc<TextSystem>,
}
impl TestApp {
/// Create a new test application.
pub fn new() -> Self {
Self::with_seed(0)
}
/// Create a new test application with a specific random seed.
pub fn with_seed(seed: u64) -> Self {
let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(seed));
let arc_dispatcher = Arc::new(dispatcher.clone());
let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
let platform = TestPlatform::new(background_executor.clone(), foreground_executor.clone());
let asset_source = Arc::new(());
let http_client = http_client::FakeHttpClient::with_404_response();
let text_system = Arc::new(TextSystem::new(platform.text_system()));
let mut app = App::new_app(platform.clone(), asset_source, http_client);
app.borrow_mut().mode = GpuiMode::test();
Self {
app,
platform,
background_executor,
foreground_executor,
dispatcher,
text_system,
}
}
/// Run a closure with mutable access to the App context.
/// Automatically runs until parked after the closure completes.
pub fn update<R>(&mut self, f: impl FnOnce(&mut App) -> R) -> R {
let result = {
let mut app = self.app.borrow_mut();
app.update(f)
};
self.run_until_parked();
result
}
/// Run a closure with read-only access to the App context.
pub fn read<R>(&self, f: impl FnOnce(&App) -> R) -> R {
let app = self.app.borrow();
f(&app)
}
/// Create a new entity in the app.
pub fn new_entity<T: 'static>(
&mut self,
build: impl FnOnce(&mut Context<T>) -> T,
) -> Entity<T> {
self.update(|cx| cx.new(build))
}
/// Update an entity.
pub fn update_entity<T: 'static, R>(
&mut self,
entity: &Entity<T>,
f: impl FnOnce(&mut T, &mut Context<T>) -> R,
) -> R {
self.update(|cx| entity.update(cx, f))
}
/// Read an entity.
pub fn read_entity<T: 'static, R>(
&self,
entity: &Entity<T>,
f: impl FnOnce(&T, &App) -> R,
) -> R {
self.read(|cx| f(entity.read(cx), cx))
}
/// Open a test window with the given root view.
pub fn open_window<V: Render + 'static>(
&mut self,
build_view: impl FnOnce(&mut Window, &mut Context<V>) -> V,
) -> TestWindow<V> {
let bounds = self.read(|cx| Bounds::maximized(None, cx));
let handle = self.update(|cx| {
cx.open_window(
WindowOptions {
window_bounds: Some(WindowBounds::Windowed(bounds)),
..Default::default()
},
|window, cx| cx.new(|cx| build_view(window, cx)),
)
.unwrap()
});
TestWindow {
handle,
app: self.app.clone(),
platform: self.platform.clone(),
background_executor: self.background_executor.clone(),
}
}
/// Open a test window with specific options.
pub fn open_window_with_options<V: Render + 'static>(
&mut self,
options: WindowOptions,
build_view: impl FnOnce(&mut Window, &mut Context<V>) -> V,
) -> TestWindow<V> {
let handle = self.update(|cx| {
cx.open_window(options, |window, cx| cx.new(|cx| build_view(window, cx)))
.unwrap()
});
TestWindow {
handle,
app: self.app.clone(),
platform: self.platform.clone(),
background_executor: self.background_executor.clone(),
}
}
/// Run pending tasks until there's nothing left to do.
pub fn run_until_parked(&self) {
self.background_executor.run_until_parked();
}
/// Advance the simulated clock by the given duration.
pub fn advance_clock(&self, duration: Duration) {
self.background_executor.advance_clock(duration);
}
/// Spawn a future on the foreground executor.
pub fn spawn<Fut, R>(&self, f: impl FnOnce(AsyncApp) -> Fut) -> Task<R>
where
Fut: Future<Output = R> + 'static,
R: 'static,
{
self.foreground_executor.spawn(f(self.to_async()))
}
/// Spawn a future on the background executor.
pub fn background_spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
where
R: Send + 'static,
{
self.background_executor.spawn(future)
}
/// Get an async handle to the app.
pub fn to_async(&self) -> AsyncApp {
AsyncApp {
app: Rc::downgrade(&self.app),
background_executor: self.background_executor.clone(),
foreground_executor: self.foreground_executor.clone(),
}
}
/// Get the background executor.
pub fn background_executor(&self) -> &BackgroundExecutor {
&self.background_executor
}
/// Get the foreground executor.
pub fn foreground_executor(&self) -> &ForegroundExecutor {
&self.foreground_executor
}
/// Get the text system.
pub fn text_system(&self) -> &Arc<TextSystem> {
&self.text_system
}
/// Check if a global of the given type exists.
pub fn has_global<G: Global>(&self) -> bool {
self.read(|cx| cx.has_global::<G>())
}
/// Set a global value.
pub fn set_global<G: Global>(&mut self, global: G) {
self.update(|cx| cx.set_global(global));
}
/// Read a global value.
pub fn read_global<G: Global, R>(&self, f: impl FnOnce(&G, &App) -> R) -> R {
self.read(|cx| f(cx.global(), cx))
}
/// Update a global value.
pub fn update_global<G: Global, R>(&mut self, f: impl FnOnce(&mut G, &mut App) -> R) -> R {
self.update(|cx| cx.update_global(f))
}
// Platform simulation methods
/// Write text to the simulated clipboard.
pub fn write_to_clipboard(&self, item: ClipboardItem) {
self.platform.write_to_clipboard(item);
}
/// Read from the simulated clipboard.
pub fn read_from_clipboard(&self) -> Option<ClipboardItem> {
self.platform.read_from_clipboard()
}
/// Get the URL most recently opened via `cx.open_url()`, if any.
pub fn opened_url(&self) -> Option<String> {
self.platform.opened_url.borrow().clone()
}
/// Check if a file path prompt is pending.
pub fn did_prompt_for_new_path(&self) -> bool {
self.platform.did_prompt_for_new_path()
}
/// Simulate answering a path selection dialog.
pub fn simulate_new_path_selection(
&self,
select: impl FnOnce(&std::path::Path) -> Option<std::path::PathBuf>,
) {
self.platform.simulate_new_path_selection(select);
}
/// Check if a prompt dialog is pending.
pub fn has_pending_prompt(&self) -> bool {
self.platform.has_pending_prompt()
}
/// Simulate answering a prompt dialog.
pub fn simulate_prompt_answer(&self, button: &str) {
self.platform.simulate_prompt_answer(button);
}
/// Get all open windows.
pub fn windows(&self) -> Vec<AnyWindowHandle> {
self.read(|cx| cx.windows())
}
}
impl Default for TestApp {
fn default() -> Self {
Self::new()
}
}
/// A test window with inspection and simulation capabilities.
pub struct TestWindow<V> {
handle: WindowHandle<V>,
app: Rc<AppCell>,
platform: Rc<TestPlatform>,
background_executor: BackgroundExecutor,
}
impl<V: 'static + Render> TestWindow<V> {
/// Get the window handle.
pub fn handle(&self) -> WindowHandle<V> {
self.handle
}
/// Get the root view entity.
pub fn root(&self) -> Entity<V> {
let mut app = self.app.borrow_mut();
let any_handle: AnyWindowHandle = self.handle.into();
app.update_window(any_handle, |root_view, _, _| {
root_view.downcast::<V>().expect("root view type mismatch")
})
.expect("window not found")
}
/// Update the root view.
/// Automatically draws the window after the update to ensure the scene is current.
pub fn update<R>(&mut self, f: impl FnOnce(&mut V, &mut Window, &mut Context<V>) -> R) -> R {
let result = {
let mut app = self.app.borrow_mut();
let any_handle: AnyWindowHandle = self.handle.into();
app.update_window(any_handle, |root_view, window, cx| {
let view = root_view.downcast::<V>().expect("root view type mismatch");
view.update(cx, |view, cx| f(view, window, cx))
})
.expect("window not found")
};
self.background_executor.run_until_parked();
self.draw();
result
}
/// Read the root view.
pub fn read<R>(&self, f: impl FnOnce(&V, &App) -> R) -> R {
let app = self.app.borrow();
let view = self
.app
.borrow()
.windows
.get(self.handle.window_id())
.and_then(|w| w.as_ref())
.and_then(|w| w.root.clone())
.and_then(|r| r.downcast::<V>().ok())
.expect("window or root view not found");
f(view.read(&app), &app)
}
/// Get the window title.
pub fn title(&self) -> Option<String> {
let app = self.app.borrow();
app.read_window(&self.handle, |_, _cx| {
// TODO: expose title through Window API
None
})
.unwrap()
}
/// Simulate a keystroke.
/// Automatically draws the window after the keystroke.
pub fn simulate_keystroke(&mut self, keystroke: &str) {
let keystroke = Keystroke::parse(keystroke).unwrap();
{
let mut app = self.app.borrow_mut();
let any_handle: AnyWindowHandle = self.handle.into();
app.update_window(any_handle, |_, window, cx| {
window.dispatch_keystroke(keystroke, cx);
})
.unwrap();
}
self.background_executor.run_until_parked();
self.draw();
}
/// Simulate multiple keystrokes (space-separated).
pub fn simulate_keystrokes(&mut self, keystrokes: &str) {
for keystroke in keystrokes.split(' ') {
self.simulate_keystroke(keystroke);
}
}
/// Simulate typing text.
pub fn simulate_input(&mut self, input: &str) {
for char in input.chars() {
self.simulate_keystroke(&char.to_string());
}
}
/// Simulate a mouse move.
pub fn simulate_mouse_move(&mut self, position: Point<Pixels>) {
self.simulate_event(MouseMoveEvent {
position,
modifiers: Default::default(),
pressed_button: None,
});
}
/// Simulate a mouse down event.
pub fn simulate_mouse_down(&mut self, position: Point<Pixels>, button: MouseButton) {
self.simulate_event(MouseDownEvent {
position,
button,
modifiers: Default::default(),
click_count: 1,
first_mouse: false,
});
}
/// Simulate a mouse up event.
pub fn simulate_mouse_up(&mut self, position: Point<Pixels>, button: MouseButton) {
self.simulate_event(MouseUpEvent {
position,
button,
modifiers: Default::default(),
click_count: 1,
});
}
/// Simulate a click at the given position.
pub fn simulate_click(&mut self, position: Point<Pixels>, button: MouseButton) {
self.simulate_mouse_down(position, button);
self.simulate_mouse_up(position, button);
}
/// Simulate a scroll event.
pub fn simulate_scroll(&mut self, position: Point<Pixels>, delta: Point<Pixels>) {
self.simulate_event(crate::ScrollWheelEvent {
position,
delta: crate::ScrollDelta::Pixels(delta),
modifiers: Default::default(),
touch_phase: crate::TouchPhase::Moved,
});
}
/// Simulate an input event.
/// Automatically draws the window after the event.
pub fn simulate_event<E: InputEvent>(&mut self, event: E) {
let platform_input = event.to_platform_input();
{
let mut app = self.app.borrow_mut();
let any_handle: AnyWindowHandle = self.handle.into();
app.update_window(any_handle, |_, window, cx| {
window.dispatch_event(platform_input, cx);
})
.unwrap();
}
self.background_executor.run_until_parked();
self.draw();
}
/// Simulate resizing the window.
/// Automatically draws the window after the resize.
pub fn simulate_resize(&mut self, size: Size<Pixels>) {
let window_id = self.handle.window_id();
let mut app = self.app.borrow_mut();
if let Some(Some(window)) = app.windows.get_mut(window_id) {
if let Some(test_window) = window.platform_window.as_test() {
test_window.simulate_resize(size);
}
}
drop(app);
self.background_executor.run_until_parked();
self.draw();
}
/// Force a redraw of the window.
pub fn draw(&mut self) {
let mut app = self.app.borrow_mut();
let any_handle: AnyWindowHandle = self.handle.into();
app.update_window(any_handle, |_, window, cx| {
window.draw(cx).clear();
})
.unwrap();
}
/// Get a snapshot of the rendered scene for inspection.
/// The scene is automatically kept up to date after `update()` and `simulate_*()` calls.
pub fn scene_snapshot(&self) -> SceneSnapshot {
let app = self.app.borrow();
let window = app
.windows
.get(self.handle.window_id())
.and_then(|w| w.as_ref())
.expect("window not found");
window.rendered_frame.scene.snapshot()
}
}
impl<V> Clone for TestWindow<V> {
fn clone(&self) -> Self {
Self {
handle: self.handle,
app: self.app.clone(),
platform: self.platform.clone(),
background_executor: self.background_executor.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{FocusHandle, Focusable, div, prelude::*};
struct Counter {
count: usize,
focus_handle: FocusHandle,
}
impl Counter {
fn new(_window: &mut Window, cx: &mut Context<Self>) -> Self {
let focus_handle = cx.focus_handle();
Self {
count: 0,
focus_handle,
}
}
fn increment(&mut self, _cx: &mut Context<Self>) {
self.count += 1;
}
}
impl Focusable for Counter {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl Render for Counter {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
div().child(format!("Count: {}", self.count))
}
}
#[test]
fn test_basic_usage() {
let mut app = TestApp::new();
let mut window = app.open_window(Counter::new);
window.update(|counter, _window, cx| {
counter.increment(cx);
});
window.read(|counter, _| {
assert_eq!(counter.count, 1);
});
}
#[test]
fn test_entity_creation() {
let mut app = TestApp::new();
let entity = app.new_entity(|cx| Counter {
count: 42,
focus_handle: cx.focus_handle(),
});
app.read_entity(&entity, |counter, _| {
assert_eq!(counter.count, 42);
});
app.update_entity(&entity, |counter, _cx| {
counter.count += 1;
});
app.read_entity(&entity, |counter, _| {
assert_eq!(counter.count, 43);
});
}
#[test]
fn test_globals() {
let mut app = TestApp::new();
struct MyGlobal(String);
impl Global for MyGlobal {}
assert!(!app.has_global::<MyGlobal>());
app.set_global(MyGlobal("hello".into()));
assert!(app.has_global::<MyGlobal>());
app.read_global::<MyGlobal, _>(|global, _| {
assert_eq!(global.0, "hello");
});
app.update_global::<MyGlobal, _>(|global, _| {
global.0 = "world".into();
});
app.read_global::<MyGlobal, _>(|global, _| {
assert_eq!(global.0, "world");
});
}
}

View File
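The "automatic effect flushing" that `TestApp::update` layers on top of `TestAppContext` amounts to draining the deferred-work queue before returning, so assertions after `update()` always observe settled state. A GPUI-free sketch of that pattern (all names here are illustrative):

```rust
use std::collections::VecDeque;

// Minimal state-plus-effects harness: update() mutates state and may queue
// effects; the harness flushes the queue before returning.
struct ToyTestApp {
    count: i32,
    effects: VecDeque<Box<dyn FnOnce(&mut i32)>>,
}

impl ToyTestApp {
    fn new() -> Self {
        Self { count: 0, effects: VecDeque::new() }
    }

    fn update(&mut self, f: impl FnOnce(&mut i32, &mut VecDeque<Box<dyn FnOnce(&mut i32)>>)) {
        f(&mut self.count, &mut self.effects);
        self.run_until_parked();
    }

    // Drain queued effects, applying each to the state in FIFO order.
    fn run_until_parked(&mut self) {
        while let Some(effect) = self.effects.pop_front() {
            effect(&mut self.count);
        }
    }
}

fn main() {
    let mut app = ToyTestApp::new();
    app.update(|count, effects| {
        *count += 1;
        effects.push_back(Box::new(|count| *count *= 10));
    });
    // Both the direct mutation and the queued effect have run.
    assert_eq!(app.count, 10);
}
```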

@@ -3,9 +3,9 @@ use crate::{
BackgroundExecutor, BorrowAppContext, Bounds, Capslock, ClipboardItem, DrawPhase, Drawable,
Element, Empty, EventEmitter, ForegroundExecutor, Global, InputEvent, Keystroke, Modifiers,
ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Pixels,
Platform, Point, Render, Result, Size, Task, TestDispatcher, TestPlatform,
TestScreenCaptureSource, TestWindow, TextSystem, VisualContext, Window, WindowBounds,
WindowHandle, WindowOptions, app::GpuiMode,
Platform, Point, Render, Result, Size, Task, TestDispatcher, TestPlatform, TestPlatformWindow,
TestScreenCaptureSource, TextSystem, VisualContext, Window, WindowBounds, WindowHandle,
WindowOptions, app::GpuiMode,
};
use anyhow::{anyhow, bail};
use futures::{Stream, StreamExt, channel::oneshot};
@@ -220,7 +220,7 @@ impl TestAppContext {
f(&cx)
}
/// Adds a new window. The Window will always be backed by a `TestWindow` which
/// Adds a new window. The Window will always be backed by a `TestPlatformWindow` which
/// can be retrieved with `self.test_window(handle)`
pub fn add_window<F, V>(&mut self, build_window: F) -> WindowHandle<V>
where
@@ -465,8 +465,8 @@ impl TestAppContext {
.unwrap();
}
/// Returns the `TestWindow` backing the given handle.
pub(crate) fn test_window(&self, window: AnyWindowHandle) -> TestWindow {
/// Returns the `TestPlatformWindow` backing the given handle.
pub(crate) fn test_window(&self, window: AnyWindowHandle) -> TestPlatformWindow {
self.app
.borrow_mut()
.windows

View File

@@ -5,14 +5,91 @@ use std::{
ops::{Add, Sub},
};
/// Maximum children per internal node (R-tree style branching factor).
/// Higher values = shorter tree = fewer cache misses, but more work per node.
const MAX_CHILDREN: usize = 12;
/// A spatial tree optimized for finding maximum ordering among intersecting bounds.
///
/// This is an R-tree variant specifically designed for the use case of assigning
/// z-order to overlapping UI elements. Key optimizations:
/// - Tracks the leaf with global max ordering for O(1) fast-path queries
/// - Uses a higher branching factor (`MAX_CHILDREN` = 12) for lower tree height
/// - Aggressive pruning during search based on max_order metadata
#[derive(Debug)]
pub(crate) struct BoundsTree<U>
where
U: Clone + Debug + Default + PartialEq,
{
root: Option<usize>,
/// All nodes stored contiguously for cache efficiency.
nodes: Vec<Node<U>>,
stack: Vec<usize>,
/// Index of the root node, if any.
root: Option<usize>,
/// Index of the leaf with the highest ordering (for fast-path lookups).
max_leaf: Option<usize>,
/// Reusable stack for tree traversal during insertion.
insert_path: Vec<usize>,
/// Reusable stack for search operations.
search_stack: Vec<usize>,
}
/// A node in the bounds tree.
#[derive(Debug, Clone)]
struct Node<U>
where
U: Clone + Debug + Default + PartialEq,
{
/// Bounding box containing this node and all descendants.
bounds: Bounds<U>,
/// Maximum ordering value in this subtree.
max_order: u32,
/// Node-specific data.
kind: NodeKind,
}
#[derive(Debug, Clone)]
enum NodeKind {
/// Leaf node containing actual bounds data.
Leaf {
/// The ordering assigned to this bounds.
order: u32,
},
/// Internal node with children.
Internal {
/// Indices of child nodes (2 to MAX_CHILDREN).
children: NodeChildren,
},
}
/// Fixed-size array for child indices, avoiding heap allocation.
#[derive(Debug, Clone)]
struct NodeChildren {
// Invariant: the child with the highest max_order is always stored last
indices: [usize; MAX_CHILDREN],
len: u8,
}
impl NodeChildren {
fn new() -> Self {
Self {
indices: [0; MAX_CHILDREN],
len: 0,
}
}
fn push(&mut self, index: usize) {
debug_assert!((self.len as usize) < MAX_CHILDREN);
self.indices[self.len as usize] = index;
self.len += 1;
}
fn len(&self) -> usize {
self.len as usize
}
fn as_slice(&self) -> &[usize] {
&self.indices[..self.len as usize]
}
}
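The fixed-size `NodeChildren` array avoids a heap allocation per internal node, at the cost of maintaining the max-order-last invariant by hand on every push. A self-contained sketch of the same idea (orders stored alongside indices here purely for illustration):

```rust
const MAX_CHILDREN: usize = 12;

// Fixed-capacity child list; the entry with the highest order is kept
// last so a search can visit the most promising child first.
struct Children {
    indices: [usize; MAX_CHILDREN],
    orders: [u32; MAX_CHILDREN],
    len: u8,
}

impl Children {
    fn new() -> Self {
        Self { indices: [0; MAX_CHILDREN], orders: [0; MAX_CHILDREN], len: 0 }
    }

    // Push a child, then restore the max-order-last invariant with at
    // most one swap (the previous maximum was already last).
    fn push(&mut self, index: usize, order: u32) {
        assert!((self.len as usize) < MAX_CHILDREN);
        let i = self.len as usize;
        self.indices[i] = index;
        self.orders[i] = order;
        self.len += 1;
        if i > 0 && self.orders[i] < self.orders[i - 1] {
            self.indices.swap(i - 1, i);
            self.orders.swap(i - 1, i);
        }
    }

    fn last(&self) -> Option<usize> {
        (self.len > 0).then(|| self.indices[self.len as usize - 1])
    }
}

fn main() {
    let mut c = Children::new();
    c.push(10, 5);
    c.push(11, 3);
    assert_eq!(c.last(), Some(10)); // order-5 child stays last
    c.push(12, 9);
    assert_eq!(c.last(), Some(12)); // new maximum moves to the end
}
```

One swap suffices because the invariant held before the push, so the only element the new child can displace is the previous maximum at the end.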
impl<U> BoundsTree<U>
@@ -26,158 +103,250 @@ where
+ Half
+ Default,
{
/// Clears all nodes from the tree.
pub fn clear(&mut self) {
self.root = None;
self.nodes.clear();
self.stack.clear();
self.root = None;
self.max_leaf = None;
self.insert_path.clear();
self.search_stack.clear();
}
/// Inserts bounds into the tree and returns its assigned ordering.
///
/// The ordering is one greater than the maximum ordering of any
/// existing bounds that intersect with the new bounds.
pub fn insert(&mut self, new_bounds: Bounds<U>) -> u32 {
// If the tree is empty, make the root the new leaf.
let Some(mut index) = self.root else {
let new_node = self.push_leaf(new_bounds, 1);
self.root = Some(new_node);
return 1;
// Find maximum ordering among intersecting bounds
let max_intersecting = self.find_max_ordering(&new_bounds);
let ordering = max_intersecting + 1;
// Insert the new leaf
let new_leaf_idx = self.insert_leaf(new_bounds, ordering);
// Update max_leaf tracking
self.max_leaf = match self.max_leaf {
None => Some(new_leaf_idx),
Some(old_idx) if self.nodes[old_idx].max_order < ordering => Some(new_leaf_idx),
some => some,
};
// Search for the best place to add the new leaf based on heuristics.
let mut max_intersecting_ordering = 0;
while let Node::Internal {
left,
right,
bounds: node_bounds,
..
} = &mut self.nodes[index]
{
let left = *left;
let right = *right;
*node_bounds = node_bounds.union(&new_bounds);
self.stack.push(index);
// Descend to the best-fit child, based on which one would increase
// the surface area the least. This attempts to keep the tree balanced
// in terms of surface area. If there is an intersection with the other child,
// add its keys to the intersections vector.
let left_cost = new_bounds.union(self.nodes[left].bounds()).half_perimeter();
let right_cost = new_bounds
.union(self.nodes[right].bounds())
.half_perimeter();
if left_cost < right_cost {
max_intersecting_ordering =
self.find_max_ordering(right, &new_bounds, max_intersecting_ordering);
index = left;
} else {
max_intersecting_ordering =
self.find_max_ordering(left, &new_bounds, max_intersecting_ordering);
index = right;
}
}
// We've found a leaf ('index' now refers to a leaf node).
// We'll insert a new parent node above the leaf and attach our new leaf to it.
let sibling = index;
// Check for collision with the located leaf node
let Node::Leaf {
bounds: sibling_bounds,
order: sibling_ordering,
..
} = &self.nodes[index]
else {
unreachable!();
};
if sibling_bounds.intersects(&new_bounds) {
max_intersecting_ordering = cmp::max(max_intersecting_ordering, *sibling_ordering);
}
let ordering = max_intersecting_ordering + 1;
let new_node = self.push_leaf(new_bounds, ordering);
let new_parent = self.push_internal(sibling, new_node);
// If there was an old parent, we need to update its children indices.
if let Some(old_parent) = self.stack.last().copied() {
let Node::Internal { left, right, .. } = &mut self.nodes[old_parent] else {
unreachable!();
};
if *left == sibling {
*left = new_parent;
} else {
*right = new_parent;
}
} else {
// If the old parent was the root, the new parent is the new root.
self.root = Some(new_parent);
}
for node_index in self.stack.drain(..).rev() {
let Node::Internal {
max_order: max_ordering,
..
} = &mut self.nodes[node_index]
else {
unreachable!()
};
if *max_ordering >= ordering {
break;
}
*max_ordering = ordering;
}
ordering
}
fn find_max_ordering(&self, index: usize, bounds: &Bounds<U>, mut max_ordering: u32) -> u32 {
match &self.nodes[index] {
Node::Leaf {
bounds: node_bounds,
order: ordering,
..
} => {
if bounds.intersects(node_bounds) {
max_ordering = cmp::max(*ordering, max_ordering);
}
/// Finds the maximum ordering among all bounds that intersect with the query.
fn find_max_ordering(&mut self, query: &Bounds<U>) -> u32 {
let Some(root_idx) = self.root else {
return 0;
};
// Fast path: check if the max-ordering leaf intersects
if let Some(max_idx) = self.max_leaf {
let max_node = &self.nodes[max_idx];
if query.intersects(&max_node.bounds) {
return max_node.max_order;
}
Node::Internal {
left,
right,
bounds: node_bounds,
max_order: node_max_ordering,
..
} => {
if bounds.intersects(node_bounds) && max_ordering < *node_max_ordering {
let left_max_ordering = self.nodes[*left].max_ordering();
let right_max_ordering = self.nodes[*right].max_ordering();
if left_max_ordering > right_max_ordering {
max_ordering = self.find_max_ordering(*left, bounds, max_ordering);
max_ordering = self.find_max_ordering(*right, bounds, max_ordering);
} else {
max_ordering = self.find_max_ordering(*right, bounds, max_ordering);
max_ordering = self.find_max_ordering(*left, bounds, max_ordering);
}
// Slow path: search the tree
self.search_stack.clear();
self.search_stack.push(root_idx);
let mut max_found = 0u32;
while let Some(node_idx) = self.search_stack.pop() {
let node = &self.nodes[node_idx];
// Pruning: skip if this subtree can't improve our result
if node.max_order <= max_found {
continue;
}
// Spatial pruning: skip if bounds don't intersect
if !query.intersects(&node.bounds) {
continue;
}
match &node.kind {
NodeKind::Leaf { order } => {
max_found = cmp::max(max_found, *order);
}
NodeKind::Internal { children } => {
// Children are maintained with highest max_order at the end.
// Push in forward order so the highest (last) child is popped first.
for &child_idx in children.as_slice() {
if self.nodes[child_idx].max_order > max_found {
self.search_stack.push(child_idx);
}
}
}
}
}
max_ordering
max_found
}
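The contract `find_max_ordering` serves, combined with `insert`, is simple to state: a new rectangle's order is one more than the highest order among rectangles it overlaps (or 1 when nothing overlaps). A brute-force O(n) reference pins that down; `FlatBoundsTree` is illustrative only, not the tree above:

```rust
#[derive(Clone, Copy)]
struct Rect { x0: i32, y0: i32, x1: i32, y1: i32 }

impl Rect {
    // Half-open intersection test: touching edges do not count as overlap.
    fn intersects(&self, other: &Rect) -> bool {
        self.x0 < other.x1 && other.x0 < self.x1
            && self.y0 < other.y1 && other.y0 < self.y1
    }
}

// Reference implementation of the ordering contract: each inserted rect
// gets max(order of intersecting rects) + 1.
struct FlatBoundsTree {
    rects: Vec<(Rect, u32)>,
}

impl FlatBoundsTree {
    fn new() -> Self { Self { rects: Vec::new() } }

    fn insert(&mut self, rect: Rect) -> u32 {
        let max = self.rects.iter()
            .filter(|(r, _)| r.intersects(&rect))
            .map(|&(_, o)| o)
            .max()
            .unwrap_or(0);
        let order = max + 1;
        self.rects.push((rect, order));
        order
    }
}

fn main() {
    let mut tree = FlatBoundsTree::new();
    let a = tree.insert(Rect { x0: 0, y0: 0, x1: 10, y1: 10 });
    let b = tree.insert(Rect { x0: 5, y0: 5, x1: 15, y1: 15 });   // overlaps a
    let c = tree.insert(Rect { x0: 20, y0: 20, x1: 30, y1: 30 }); // disjoint
    assert_eq!((a, b, c), (1, 2, 1));
}
```

The R-tree, max-leaf fast path, and pruning in the real implementation only make this query sublinear; the assigned orderings must match this flat version.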
fn push_leaf(&mut self, bounds: Bounds<U>, order: u32) -> usize {
self.nodes.push(Node::Leaf { bounds, order });
self.nodes.len() - 1
}
fn push_internal(&mut self, left: usize, right: usize) -> usize {
let left_node = &self.nodes[left];
let right_node = &self.nodes[right];
let new_bounds = left_node.bounds().union(right_node.bounds());
let max_ordering = cmp::max(left_node.max_ordering(), right_node.max_ordering());
self.nodes.push(Node::Internal {
bounds: new_bounds,
left,
right,
max_order: max_ordering,
/// Inserts a leaf node with the given bounds and ordering.
/// Returns the index of the new leaf.
fn insert_leaf(&mut self, bounds: Bounds<U>, order: u32) -> usize {
let new_leaf_idx = self.nodes.len();
self.nodes.push(Node {
bounds: bounds.clone(),
max_order: order,
kind: NodeKind::Leaf { order },
});
self.nodes.len() - 1
let Some(root_idx) = self.root else {
// Tree is empty, new leaf becomes root
self.root = Some(new_leaf_idx);
return new_leaf_idx;
};
// If root is a leaf, create internal node with both
if matches!(self.nodes[root_idx].kind, NodeKind::Leaf { .. }) {
let root_bounds = self.nodes[root_idx].bounds.clone();
let root_order = self.nodes[root_idx].max_order;
let mut children = NodeChildren::new();
// Maintain the max-order-child-last invariant
if order > root_order {
children.push(root_idx);
children.push(new_leaf_idx);
} else {
children.push(new_leaf_idx);
children.push(root_idx);
}
let new_root_idx = self.nodes.len();
self.nodes.push(Node {
bounds: root_bounds.union(&bounds),
max_order: cmp::max(root_order, order),
kind: NodeKind::Internal { children },
});
self.root = Some(new_root_idx);
return new_leaf_idx;
}
// Descend to find the best internal node to insert into
self.insert_path.clear();
let mut current_idx = root_idx;
loop {
let current = &self.nodes[current_idx];
let NodeKind::Internal { children } = &current.kind else {
unreachable!("Should only traverse internal nodes");
};
self.insert_path.push(current_idx);
// Find the best child to descend into
let mut best_child_idx = children.as_slice()[0];
let mut best_child_pos = 0;
let mut best_cost = bounds
.union(&self.nodes[best_child_idx].bounds)
.half_perimeter();
for (pos, &child_idx) in children.as_slice().iter().enumerate().skip(1) {
let cost = bounds.union(&self.nodes[child_idx].bounds).half_perimeter();
if cost < best_cost {
best_cost = cost;
best_child_idx = child_idx;
best_child_pos = pos;
}
}
// Check if best child is a leaf or internal
if matches!(self.nodes[best_child_idx].kind, NodeKind::Leaf { .. }) {
// Best child is a leaf. Check if current node has room for another child.
if children.len() < MAX_CHILDREN {
// Add new leaf directly to this node
let node = &mut self.nodes[current_idx];
if let NodeKind::Internal { children } = &mut node.kind {
children.push(new_leaf_idx);
// Keep the highest-max_order child last: swap the new leaf back if it isn't the max
if order <= node.max_order {
let last = children.len() - 1;
children.indices.swap(last - 1, last);
}
}
node.bounds = node.bounds.union(&bounds);
node.max_order = cmp::max(node.max_order, order);
break;
} else {
// Node is full, create new internal with [best_leaf, new_leaf]
let sibling_bounds = self.nodes[best_child_idx].bounds.clone();
let sibling_order = self.nodes[best_child_idx].max_order;
let mut new_children = NodeChildren::new();
// Invariant: the child with the highest max_order is kept last
if order > sibling_order {
new_children.push(best_child_idx);
new_children.push(new_leaf_idx);
} else {
new_children.push(new_leaf_idx);
new_children.push(best_child_idx);
}
let new_internal_idx = self.nodes.len();
let new_internal_max = cmp::max(sibling_order, order);
self.nodes.push(Node {
bounds: sibling_bounds.union(&bounds),
max_order: new_internal_max,
kind: NodeKind::Internal {
children: new_children,
},
});
// Replace the leaf with the new internal in parent
let parent = &mut self.nodes[current_idx];
if let NodeKind::Internal { children } = &mut parent.kind {
let children_len = children.len();
children.indices[best_child_pos] = new_internal_idx;
// If new internal has highest max_order, swap it to the end
// to maintain sorting invariant
if new_internal_max > parent.max_order {
children.indices.swap(best_child_pos, children_len - 1);
}
}
break;
}
} else {
// Best child is internal, continue descent
current_idx = best_child_idx;
}
}
// Propagate bounds and max_order updates up the tree
let mut updated_child_idx = None;
for &node_idx in self.insert_path.iter().rev() {
let node = &mut self.nodes[node_idx];
node.bounds = node.bounds.union(&bounds);
if node.max_order < order {
node.max_order = order;
// Swap updated child to end (skip first iteration since the invariant is already handled by previous cases)
if let Some(child_idx) = updated_child_idx {
if let NodeKind::Internal { children } = &mut node.kind {
if let Some(pos) = children.as_slice().iter().position(|&c| c == child_idx)
{
let last = children.len() - 1;
if pos != last {
children.indices.swap(pos, last);
}
}
}
}
}
updated_child_idx = Some(node_idx);
}
new_leaf_idx
}
}
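The descent above picks the child whose bounds would grow the least, measured by the half-perimeter of the union. A standalone sketch of that cost function, using a toy f32 bounds type (an assumption for illustration, not the gpui `Bounds<U>`):

```rust
// Toy bounds type standing in for gpui's Bounds<U> (assumed for this sketch).
#[derive(Clone, Copy)]
struct B {
    x: f32,
    y: f32,
    w: f32,
    h: f32,
}

impl B {
    // Smallest rectangle covering both bounds.
    fn union(self, o: B) -> B {
        let x = self.x.min(o.x);
        let y = self.y.min(o.y);
        let r = (self.x + self.w).max(o.x + o.w);
        let b = (self.y + self.h).max(o.y + o.h);
        B { x, y, w: r - x, h: b - y }
    }

    // The cost metric used when choosing which child to descend into.
    fn half_perimeter(self) -> f32 {
        self.w + self.h
    }
}

fn main() {
    let new = B { x: 0.0, y: 0.0, w: 1.0, h: 1.0 };
    let near = B { x: 1.0, y: 0.0, w: 1.0, h: 1.0 };
    let far = B { x: 10.0, y: 10.0, w: 1.0, h: 1.0 };
    // The nearby child yields the smaller union, so it wins the descent.
    assert!(new.union(near).half_perimeter() < new.union(far).half_perimeter());
}
```

Minimizing union growth keeps sibling bounds tight, which is what makes later intersection queries against the tree cheap.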
@@ -187,50 +356,11 @@ where
{
fn default() -> Self {
BoundsTree {
root: None,
nodes: Vec::new(),
stack: Vec::new(),
root: None,
max_leaf: None,
insert_path: Vec::new(),
search_stack: Vec::new(),
}
}
}
#[derive(Debug, Clone)]
enum Node<U>
where
U: Clone + Debug + Default + PartialEq,
{
Leaf {
bounds: Bounds<U>,
order: u32,
},
Internal {
left: usize,
right: usize,
bounds: Bounds<U>,
max_order: u32,
},
}
impl<U> Node<U>
where
U: Clone + Debug + Default + PartialEq,
{
fn bounds(&self) -> &Bounds<U> {
match self {
Node::Leaf { bounds, .. } => bounds,
Node::Internal { bounds, .. } => bounds,
}
}
fn max_ordering(&self) -> u32 {
match self {
Node::Leaf {
order: ordering, ..
} => *ordering,
Node::Internal {
max_order: max_ordering,
..
} => *max_ordering,
}
}
}

View File

@@ -808,6 +808,15 @@ impl LinearColorStop {
}
impl Background {
/// Returns the solid color if this is a solid background, None otherwise.
pub fn as_solid(&self) -> Option<Hsla> {
if self.tag == BackgroundTag::Solid {
Some(self.solid)
} else {
None
}
}
/// Use specified color space for color interpolation.
///
/// <https://developer.mozilla.org/en-US/docs/Web/CSS/color-interpolation-method>

View File

@@ -3193,7 +3193,11 @@ impl ScrollHandle {
match active_item.strategy {
ScrollStrategy::FirstVisible => {
if state.overflow.y == Overflow::Scroll {
if bounds.top() + scroll_offset.y < state.bounds.top() {
let child_height = bounds.size.height;
let viewport_height = state.bounds.size.height;
if child_height > viewport_height {
scroll_offset.y = state.bounds.top() - bounds.top();
} else if bounds.top() + scroll_offset.y < state.bounds.top() {
scroll_offset.y = state.bounds.top() - bounds.top();
} else if bounds.bottom() + scroll_offset.y > state.bounds.bottom() {
scroll_offset.y = state.bounds.bottom() - bounds.bottom();
@@ -3206,7 +3210,11 @@ impl ScrollHandle {
}
if state.overflow.x == Overflow::Scroll {
if bounds.left() + scroll_offset.x < state.bounds.left() {
let child_width = bounds.size.width;
let viewport_width = state.bounds.size.width;
if child_width > viewport_width {
scroll_offset.x = state.bounds.left() - bounds.left();
} else if bounds.left() + scroll_offset.x < state.bounds.left() {
scroll_offset.x = state.bounds.left() - bounds.left();
} else if bounds.right() + scroll_offset.x > state.bounds.right() {
scroll_offset.x = state.bounds.right() - bounds.right();
@@ -3268,3 +3276,46 @@ impl ScrollHandle {
self.0.borrow().child_bounds.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn scroll_handle_aligns_wide_children_to_left_edge() {
let handle = ScrollHandle::new();
{
let mut state = handle.0.borrow_mut();
state.bounds = Bounds::new(point(px(0.), px(0.)), size(px(80.), px(20.)));
state.child_bounds = vec![Bounds::new(point(px(25.), px(0.)), size(px(200.), px(20.)))];
state.overflow.x = Overflow::Scroll;
state.active_item = Some(ScrollActiveItem {
index: 0,
strategy: ScrollStrategy::default(),
});
}
handle.scroll_to_active_item();
assert_eq!(handle.offset().x, px(-25.));
}
#[test]
fn scroll_handle_aligns_tall_children_to_top_edge() {
let handle = ScrollHandle::new();
{
let mut state = handle.0.borrow_mut();
state.bounds = Bounds::new(point(px(0.), px(0.)), size(px(20.), px(80.)));
state.child_bounds = vec![Bounds::new(point(px(0.), px(25.)), size(px(20.), px(200.)))];
state.overflow.y = Overflow::Scroll;
state.active_item = Some(ScrollActiveItem {
index: 0,
strategy: ScrollStrategy::default(),
});
}
handle.scroll_to_active_item();
assert_eq!(handle.offset().y, px(-25.));
}
}

View File

@@ -1,4 +1,4 @@
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
@@ -47,6 +47,52 @@ pub struct ForegroundExecutor {
not_send: PhantomData<Rc<()>>,
}
/// Realtime task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
/// Audio task
Audio,
/// Other realtime task
#[default]
Other,
}
/// Task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
/// Realtime priority
///
/// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
Realtime(RealtimePriority),
/// High priority
///
/// Only use for tasks that are critical to the user experience / responsiveness of the editor.
High,
/// Medium priority, probably suits most of your use cases.
#[default]
Medium,
/// Low priority
///
/// Use this for background work that can arrive in large quantities,
/// so it does not starve the executor of resources needed by
/// higher-priority tasks.
}
impl Priority {
#[allow(dead_code)]
pub(crate) const fn probability(&self) -> u32 {
match self {
// realtime priorities are not considered for probability scheduling
Priority::Realtime(_) => 0,
Priority::High => 60,
Priority::Medium => 30,
Priority::Low => 10,
}
}
}
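The comment above hints that these weights drive probability-based scheduling. One plausible reading (a hypothetical sketch, not GPUI's actual scheduler) is a lottery: a uniform roll in `0..100` selects a queue in proportion to its weight, while realtime tasks bypass the lottery entirely.

```rust
// Hypothetical lottery mapping over the High/Medium/Low weights (60/30/10).
// Realtime tasks have weight 0 and never enter this selection.
fn pick(roll: u32) -> &'static str {
    debug_assert!(roll < 100);
    match roll {
        0..=59 => "high",    // 60% of rolls
        60..=89 => "medium", // 30% of rolls
        _ => "low",          // remaining 10%
    }
}

fn main() {
    assert_eq!(pick(0), "high");
    assert_eq!(pick(59), "high");
    assert_eq!(pick(60), "medium");
    assert_eq!(pick(95), "low");
}
```

A weighted lottery, unlike strict priority ordering, guarantees low-priority work still makes progress even under sustained high-priority load.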
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
@@ -152,7 +198,20 @@ impl BackgroundExecutor {
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), None)
self.spawn_with_priority(Priority::default(), future)
}
/// Enqueues the given future to be run to completion on a background thread.
#[track_caller]
pub fn spawn_with_priority<R>(
&self,
priority: Priority,
future: impl Future<Output = R> + Send + 'static,
) -> Task<R>
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), None, priority)
}
/// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
@@ -199,7 +258,13 @@ impl BackgroundExecutor {
let _notify_guard = NotifyOnDrop(pair);
future.await
},
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), None),
move |runnable| {
dispatcher.dispatch(
RunnableVariant::Meta(runnable),
None,
Priority::default(),
)
},
)
};
runnable.schedule();
@@ -217,7 +282,7 @@ impl BackgroundExecutor {
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), Some(label))
self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
}
#[track_caller]
@@ -225,15 +290,55 @@ impl BackgroundExecutor {
&self,
future: AnyFuture<R>,
label: Option<TaskLabel>,
priority: Priority,
) -> Task<R> {
let dispatcher = self.dispatcher.clone();
let location = core::panic::Location::caller();
let (runnable, task) = async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), label),
let (runnable, task) = if let Priority::Realtime(realtime) = priority {
let location = core::panic::Location::caller();
let (tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
dispatcher.spawn_realtime(
realtime,
Box::new(move || {
while let Ok(runnable) = rx.recv() {
let start = Instant::now();
let location = runnable.metadata().location;
let mut timing = TaskTiming {
location,
start,
end: None,
};
profiler::add_task_timing(timing);
runnable.run();
let end = Instant::now();
timing.end = Some(end);
profiler::add_task_timing(timing);
}
}),
);
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| {
let _ = tx.send(runnable);
},
)
} else {
let location = core::panic::Location::caller();
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| {
dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
},
)
};
runnable.schedule();
Task(TaskState::Spawned(task))
}
@@ -406,11 +511,28 @@ impl BackgroundExecutor {
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone());
let mut scope = Scope::new(self.clone(), Priority::default());
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn(f))
.map(|f| self.spawn_with_priority(scope.priority, f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
}
}
/// Like [`BackgroundExecutor::scoped`], but spawns every task with the
/// given priority, waiting for all of them to complete before returning.
pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone(), priority);
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn_with_priority(scope.priority, f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
@@ -546,6 +668,19 @@ impl ForegroundExecutor {
/// Enqueues the given Task to run on the main thread at some point in the future.
#[track_caller]
pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
where
R: 'static,
{
self.spawn_with_priority(Priority::default(), future)
}
/// Enqueues the given Task to run on the main thread at some point in the future.
#[track_caller]
pub fn spawn_with_priority<R>(
&self,
priority: Priority,
future: impl Future<Output = R> + 'static,
) -> Task<R>
where
R: 'static,
{
@@ -557,16 +692,19 @@ impl ForegroundExecutor {
dispatcher: Arc<dyn PlatformDispatcher>,
future: AnyLocalFuture<R>,
location: &'static core::panic::Location<'static>,
priority: Priority,
) -> Task<R> {
let (runnable, task) = spawn_local_with_source_location(
future,
move |runnable| dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable)),
move |runnable| {
dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
},
RunnableMeta { location },
);
runnable.schedule();
Task(TaskState::Spawned(task))
}
inner::<R>(dispatcher, Box::pin(future), location)
inner::<R>(dispatcher, Box::pin(future), location, priority)
}
}
@@ -642,6 +780,7 @@ where
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
executor: BackgroundExecutor,
priority: Priority,
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
tx: Option<mpsc::Sender<()>>,
rx: mpsc::Receiver<()>,
@@ -649,10 +788,11 @@ pub struct Scope<'a> {
}
impl<'a> Scope<'a> {
fn new(executor: BackgroundExecutor) -> Self {
fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
let (tx, rx) = mpsc::channel(1);
Self {
executor,
priority,
tx: Some(tx),
rx,
futures: Default::default(),

View File

@@ -1416,9 +1416,9 @@ where
/// ```
pub fn contains(&self, point: &Point<T>) -> bool {
point.x >= self.origin.x
&& point.x <= self.origin.x.clone() + self.size.width.clone()
&& point.x < self.origin.x.clone() + self.size.width.clone()
&& point.y >= self.origin.y
&& point.y <= self.origin.y.clone() + self.size.height.clone()
&& point.y < self.origin.y.clone() + self.size.height.clone()
}
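The change from `<=` to `<` makes containment half-open, so a point on the shared edge of two adjacent bounds now belongs to exactly one of them. A minimal standalone illustration (the toy struct below is assumed for the example, not the gpui type):

```rust
// Toy bounds mirroring the new half-open `contains` semantics:
// points on the right/bottom edge are now outside.
struct Bounds {
    x: f32,
    y: f32,
    w: f32,
    h: f32,
}

impl Bounds {
    fn contains(&self, px: f32, py: f32) -> bool {
        px >= self.x
            && px < self.x + self.w
            && py >= self.y
            && py < self.y + self.h
    }
}

fn main() {
    let b = Bounds { x: 0.0, y: 0.0, w: 10.0, h: 10.0 };
    assert!(b.contains(0.0, 0.0));   // origin (top-left) is inside
    assert!(!b.contains(10.0, 5.0)); // right edge is now excluded
    assert!(!b.contains(5.0, 10.0)); // bottom edge is now excluded
}
```

With the old inclusive check, two side-by-side elements would both report containing a point on their shared edge; half-open intervals remove that ambiguity in hit testing.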
/// Checks if this bounds is completely contained within another bounds.

View File

@@ -31,6 +31,8 @@ mod path_builder;
mod platform;
pub mod prelude;
mod profiler;
#[cfg(any(target_os = "windows", target_os = "linux"))]
mod queue;
mod scene;
mod shared_string;
mod shared_uri;
@@ -89,16 +91,20 @@ pub use keymap::*;
pub use path_builder::*;
pub use platform::*;
pub use profiler::*;
#[cfg(any(target_os = "windows", target_os = "linux"))]
pub(crate) use queue::{PriorityQueueReceiver, PriorityQueueSender};
pub use refineable::*;
pub use scene::*;
pub use shared_string::*;
pub use shared_uri::*;
pub use smol::Timer;
use std::{any::Any, future::Future};
pub use style::*;
pub use styled::*;
pub use subscription::*;
pub use svg_renderer::*;
pub(crate) use tab_stop::*;
use taffy::TaffyLayoutEngine;
pub use taffy::{AvailableSpace, LayoutId};
#[cfg(any(test, feature = "test-support"))]
pub use test::*;
@@ -109,9 +115,6 @@ pub use util::{FutureExt, Timeout, arc_cow::ArcCow};
pub use view::*;
pub use window::*;
use std::{any::Any, future::Future};
use taffy::TaffyLayoutEngine;
/// The context trait, allows the different contexts in GPUI to be used
/// interchangeably for certain operations.
pub trait AppContext {

View File

@@ -39,9 +39,10 @@ use crate::{
Action, AnyWindowHandle, App, AsyncWindowContext, BackgroundExecutor, Bounds,
DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun,
ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput,
Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Scene, ShapedGlyph,
ShapedRun, SharedString, Size, SvgRenderer, SystemWindowTab, Task, TaskLabel, TaskTiming,
ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
Point, Priority, RealtimePriority, RenderGlyphParams, RenderImage, RenderImageParams,
RenderSvgParams, Scene, ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer,
SystemWindowTab, Task, TaskLabel, TaskTiming, ThreadTaskTimings, Window, WindowControlArea,
hash, point, px, size,
};
use anyhow::Result;
use async_task::Runnable;
@@ -560,7 +561,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
fn update_ime_position(&self, _bounds: Bounds<Pixels>);
#[cfg(any(test, feature = "test-support"))]
fn as_test(&mut self) -> Option<&mut TestWindow> {
fn as_test(&mut self) -> Option<&mut TestPlatformWindow> {
None
}
}
@@ -587,9 +588,10 @@ pub trait PlatformDispatcher: Send + Sync {
fn get_all_timings(&self) -> Vec<ThreadTaskTimings>;
fn get_current_thread_timings(&self) -> Vec<TaskTiming>;
fn is_main_thread(&self) -> bool;
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant);
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority);
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant);
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>);
fn now(&self) -> Instant {
Instant::now()

View File

@@ -1,9 +1,10 @@
use crate::{
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableVariant, THREAD_TIMINGS, TaskLabel,
TaskTiming, ThreadTaskTimings,
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueReceiver,
PriorityQueueSender, RealtimePriority, RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming,
ThreadTaskTimings, profiler,
};
use calloop::{
EventLoop,
EventLoop, PostAction,
channel::{self, Sender},
timer::TimeoutAction,
};
@@ -19,9 +20,9 @@ struct TimerAfter {
}
pub(crate) struct LinuxDispatcher {
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueCalloopSender<RunnableVariant>,
timer_sender: Sender<TimerAfter>,
background_sender: flume::Sender<RunnableVariant>,
background_sender: PriorityQueueSender<RunnableVariant>,
_background_threads: Vec<thread::JoinHandle<()>>,
main_thread_id: thread::ThreadId,
}
@@ -29,18 +30,20 @@ pub(crate) struct LinuxDispatcher {
const MIN_THREADS: usize = 2;
impl LinuxDispatcher {
pub fn new(main_sender: Sender<RunnableVariant>) -> Self {
let (background_sender, background_receiver) = flume::unbounded::<RunnableVariant>();
pub fn new(main_sender: PriorityQueueCalloopSender<RunnableVariant>) -> Self {
let (background_sender, background_receiver) = PriorityQueueReceiver::new();
let thread_count =
std::thread::available_parallelism().map_or(MIN_THREADS, |i| i.get().max(MIN_THREADS));
// These threads should really be lower priority than the foreground
// executor.
let mut background_threads = (0..thread_count)
.map(|i| {
let receiver = background_receiver.clone();
let mut receiver = background_receiver.clone();
std::thread::Builder::new()
.name(format!("Worker-{i}"))
.spawn(move || {
for runnable in receiver {
for runnable in receiver.iter() {
let start = Instant::now();
let mut location = match runnable {
@@ -51,7 +54,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -63,7 +66,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -72,7 +75,7 @@ impl LinuxDispatcher {
let end = Instant::now();
location.end = Some(end);
Self::add_task_timing(location);
profiler::add_task_timing(location);
log::trace!(
"background thread {}: ran runnable. took: {:?}",
@@ -113,7 +116,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -124,7 +127,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -133,7 +136,7 @@ impl LinuxDispatcher {
let end = Instant::now();
timing.end = Some(end);
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
}
TimeoutAction::Drop
},
@@ -157,22 +160,6 @@ impl LinuxDispatcher {
main_thread_id: thread::current().id(),
}
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
}
}
impl PlatformDispatcher for LinuxDispatcher {
@@ -199,22 +186,26 @@ impl PlatformDispatcher for LinuxDispatcher {
thread::current().id() == self.main_thread_id
}
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
self.background_sender.send(runnable).unwrap();
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
self.background_sender
.send(priority, runnable)
.unwrap_or_else(|_| panic!("background priority queue receiver was dropped"));
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
self.main_sender.send(runnable).unwrap_or_else(|runnable| {
// NOTE: Runnable may wrap a Future that is !Send.
//
// This is usually safe because we only poll it on the main thread.
// However if the send fails, we know that:
// 1. main_receiver has been dropped (which implies the app is shutting down)
// 2. we are on a background thread.
// It is not safe to drop something !Send on the wrong thread, and
// the app will exit soon anyway, so we must forget the runnable.
std::mem::forget(runnable);
});
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
self.main_sender
.send(priority, runnable)
.unwrap_or_else(|runnable| {
// NOTE: Runnable may wrap a Future that is !Send.
//
// This is usually safe because we only poll it on the main thread.
// However if the send fails, we know that:
// 1. main_receiver has been dropped (which implies the app is shutting down)
// 2. we are on a background thread.
// It is not safe to drop something !Send on the wrong thread, and
// the app will exit soon anyway, so we must forget the runnable.
std::mem::forget(runnable);
});
}
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
@@ -222,4 +213,252 @@ impl PlatformDispatcher for LinuxDispatcher {
.send(TimerAfter { duration, runnable })
.ok();
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
let policy = match priority {
RealtimePriority::Audio => libc::SCHED_FIFO,
RealtimePriority::Other => libc::SCHED_RR,
};
let sched_priority = match priority {
RealtimePriority::Audio => 65,
RealtimePriority::Other => 45,
};
let sched_param = libc::sched_param { sched_priority };
// SAFETY: sched_param is a valid initialized structure
let result = unsafe { libc::pthread_setschedparam(thread_id, policy, &sched_param) };
if result != 0 {
log::warn!("failed to set realtime thread priority to {:?}", priority);
}
f();
});
}
}
pub struct PriorityQueueCalloopSender<T> {
sender: PriorityQueueSender<T>,
ping: calloop::ping::Ping,
}
impl<T> PriorityQueueCalloopSender<T> {
fn new(tx: PriorityQueueSender<T>, ping: calloop::ping::Ping) -> Self {
Self { sender: tx, ping }
}
fn send(&self, priority: Priority, item: T) -> Result<(), crate::queue::SendError<T>> {
let res = self.sender.send(priority, item);
if res.is_ok() {
self.ping.ping();
}
res
}
}
impl<T> Drop for PriorityQueueCalloopSender<T> {
fn drop(&mut self) {
self.ping.ping();
}
}
pub struct PriorityQueueCalloopReceiver<T> {
receiver: PriorityQueueReceiver<T>,
source: calloop::ping::PingSource,
ping: calloop::ping::Ping,
}
impl<T> PriorityQueueCalloopReceiver<T> {
pub fn new() -> (PriorityQueueCalloopSender<T>, Self) {
let (ping, source) = calloop::ping::make_ping().expect("Failed to create a Ping.");
let (tx, rx) = PriorityQueueReceiver::new();
(
PriorityQueueCalloopSender::new(tx, ping.clone()),
Self {
receiver: rx,
source,
ping,
},
)
}
}
use calloop::channel::Event;
#[derive(Debug)]
pub struct ChannelError(calloop::ping::PingError);
impl std::fmt::Display for ChannelError {
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for ChannelError {
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.0)
}
}
impl<T> calloop::EventSource for PriorityQueueCalloopReceiver<T> {
type Event = Event<T>;
type Metadata = ();
type Ret = ();
type Error = ChannelError;
fn process_events<F>(
&mut self,
readiness: calloop::Readiness,
token: calloop::Token,
mut callback: F,
) -> Result<calloop::PostAction, Self::Error>
where
F: FnMut(Self::Event, &mut Self::Metadata) -> Self::Ret,
{
let mut clear_readiness = false;
let mut disconnected = false;
let action = self
.source
.process_events(readiness, token, |(), &mut ()| {
let mut is_empty = true;
let mut receiver = self.receiver.clone();
for runnable in receiver.try_iter() {
match runnable {
Ok(r) => {
callback(Event::Msg(r), &mut ());
is_empty = false;
}
Err(_) => {
disconnected = true;
}
}
}
if disconnected {
callback(Event::Closed, &mut ());
}
if is_empty {
clear_readiness = true;
}
})
.map_err(ChannelError)?;
if disconnected {
Ok(PostAction::Remove)
} else if clear_readiness {
Ok(action)
} else {
// Re-notify the ping source so we can try again.
self.ping.ping();
Ok(PostAction::Continue)
}
}
fn register(
&mut self,
poll: &mut calloop::Poll,
token_factory: &mut calloop::TokenFactory,
) -> calloop::Result<()> {
self.source.register(poll, token_factory)
}
fn reregister(
&mut self,
poll: &mut calloop::Poll,
token_factory: &mut calloop::TokenFactory,
) -> calloop::Result<()> {
self.source.reregister(poll, token_factory)
}
fn unregister(&mut self, poll: &mut calloop::Poll) -> calloop::Result<()> {
self.source.unregister(poll)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn calloop_works() {
let mut event_loop = calloop::EventLoop::try_new().unwrap();
let handle = event_loop.handle();
let (tx, rx) = PriorityQueueCalloopReceiver::new();
struct Data {
got_msg: bool,
got_closed: bool,
}
let mut data = Data {
got_msg: false,
got_closed: false,
};
let _channel_token = handle
.insert_source(rx, move |evt, &mut (), data: &mut Data| match evt {
Event::Msg(()) => {
data.got_msg = true;
}
Event::Closed => {
data.got_closed = true;
}
})
.unwrap();
// nothing is sent, nothing is received
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(!data.got_msg);
assert!(!data.got_closed);
// a message is sent
tx.send(Priority::Medium, ()).unwrap();
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(data.got_msg);
assert!(!data.got_closed);
// the sender is dropped
drop(tx);
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(data.got_msg);
assert!(data.got_closed);
}
}

View File

@@ -14,7 +14,7 @@ use std::{
};
use anyhow::{Context as _, anyhow};
use calloop::{LoopSignal, channel::Channel};
use calloop::LoopSignal;
use futures::channel::oneshot;
use util::ResultExt as _;
use util::command::{new_smol_command, new_std_command};
@@ -25,8 +25,8 @@ use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
PlatformTextSystem, PlatformWindow, Point, Result, RunnableVariant, Task, WindowAppearance,
WindowParams, px,
PlatformTextSystem, PlatformWindow, Point, PriorityQueueCalloopReceiver, Result,
RunnableVariant, Task, WindowAppearance, WindowParams, px,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
@@ -149,8 +149,8 @@ pub(crate) struct LinuxCommon {
}
impl LinuxCommon {
pub fn new(signal: LoopSignal) -> (Self, Channel<RunnableVariant>) {
let (main_sender, main_receiver) = calloop::channel::channel::<RunnableVariant>();
pub fn new(signal: LoopSignal) -> (Self, PriorityQueueCalloopReceiver<RunnableVariant>) {
let (main_sender, main_receiver) = PriorityQueueCalloopReceiver::new();
#[cfg(any(feature = "wayland", feature = "x11"))]
let text_system = Arc::new(crate::CosmicTextSystem::new());

View File

@@ -77,10 +77,10 @@ use crate::{
LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay,
PlatformInput, PlatformKeyboardLayout, Point, ResultExt as _, SCROLL_LINES, ScrollDelta,
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size,
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, profiler, px, size,
};
use crate::{
LinuxDispatcher, RunnableVariant, TaskTiming,
RunnableVariant, TaskTiming,
platform::{PlatformWindow, blade::BladeContext},
};
use crate::{
@@ -503,7 +503,7 @@ impl WaylandClient {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -515,7 +515,7 @@ impl WaylandClient {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -524,7 +524,7 @@ impl WaylandClient {
let end = Instant::now();
timing.end = Some(end);
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
});
}
}

View File

@@ -1,4 +1,4 @@
use crate::{Capslock, LinuxDispatcher, ResultExt as _, RunnableVariant, TaskTiming, xcb_flush};
use crate::{Capslock, ResultExt as _, RunnableVariant, TaskTiming, profiler, xcb_flush};
use anyhow::{Context as _, anyhow};
use ashpd::WindowIdentifier;
use calloop::{
@@ -322,7 +322,7 @@ impl X11Client {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -334,7 +334,7 @@ impl X11Client {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -343,7 +343,7 @@ impl X11Client {
let end = Instant::now();
timing.end = Some(end);
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
});
}
}

View File

@@ -3,11 +3,22 @@
#![allow(non_snake_case)]
use crate::{
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableMeta, RunnableVariant, THREAD_TIMINGS,
TaskLabel, TaskTiming, ThreadTaskTimings,
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, RealtimePriority, RunnableMeta,
RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming, ThreadTaskTimings,
};
use anyhow::Context;
use async_task::Runnable;
use mach2::{
kern_return::KERN_SUCCESS,
mach_time::mach_timebase_info_data_t,
thread_policy::{
THREAD_EXTENDED_POLICY, THREAD_EXTENDED_POLICY_COUNT, THREAD_PRECEDENCE_POLICY,
THREAD_PRECEDENCE_POLICY_COUNT, THREAD_TIME_CONSTRAINT_POLICY,
THREAD_TIME_CONSTRAINT_POLICY_COUNT, thread_extended_policy_data_t,
thread_precedence_policy_data_t, thread_time_constraint_policy_data_t,
},
};
use objc::{
class, msg_send,
runtime::{BOOL, YES},
@@ -15,9 +26,11 @@ use objc::{
};
use std::{
ffi::c_void,
mem::MaybeUninit,
ptr::{NonNull, addr_of},
time::{Duration, Instant},
};
use util::ResultExt;
/// All items in the generated file are marked as pub, so we're gonna wrap it in a separate mod to prevent
/// these pub items from leaking into public API.
@@ -56,7 +69,7 @@ impl PlatformDispatcher for MacDispatcher {
is_main_thread == YES
}
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
let (context, trampoline) = match runnable {
RunnableVariant::Meta(runnable) => (
runnable.into_raw().as_ptr() as *mut c_void,
@@ -67,16 +80,24 @@ impl PlatformDispatcher for MacDispatcher {
Some(trampoline_compat as unsafe extern "C" fn(*mut c_void)),
),
};
let queue_priority = match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => DISPATCH_QUEUE_PRIORITY_HIGH as isize,
Priority::Medium => DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
Priority::Low => DISPATCH_QUEUE_PRIORITY_LOW as isize,
};
unsafe {
dispatch_async_f(
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH.try_into().unwrap(), 0),
dispatch_get_global_queue(queue_priority, 0),
context,
trampoline,
);
}
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
let (context, trampoline) = match runnable {
RunnableVariant::Meta(runnable) => (
runnable.into_raw().as_ptr() as *mut c_void,
@@ -110,6 +131,120 @@ impl PlatformDispatcher for MacDispatcher {
dispatch_after_f(when, queue, context, trampoline);
}
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
match priority {
RealtimePriority::Audio => set_audio_thread_priority(),
RealtimePriority::Other => set_high_thread_priority(),
}
.with_context(|| format!("setting realtime thread priority {priority:?}"))
.log_err();
f();
});
}
}
fn set_high_thread_priority() -> anyhow::Result<()> {
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
// SAFETY: all sched_param members are valid when initialized to zero.
let mut sched_param = unsafe { MaybeUninit::<libc::sched_param>::zeroed().assume_init() };
sched_param.sched_priority = 45;
let result = unsafe { libc::pthread_setschedparam(thread_id, libc::SCHED_FIFO, &sched_param) };
if result != 0 {
anyhow::bail!("failed to set realtime thread priority (pthread_setschedparam returned {result})")
}
Ok(())
}
fn set_audio_thread_priority() -> anyhow::Result<()> {
// https://chromium.googlesource.com/chromium/chromium/+/master/base/threading/platform_thread_mac.mm#93
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
// SAFETY: thread_id is a valid thread id
let thread_id = unsafe { libc::pthread_mach_thread_np(thread_id) };
// Fixed priority thread
let mut policy = thread_extended_policy_data_t { timeshare: 0 };
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_extended_policy_data_t is passed as THREAD_EXTENDED_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_EXTENDED_POLICY,
&mut policy as *mut _ as *mut _,
THREAD_EXTENDED_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread extended policy");
}
// relatively high priority
let mut precedence = thread_precedence_policy_data_t { importance: 63 };
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_precedence_policy_data_t is passed as THREAD_PRECEDENCE_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_PRECEDENCE_POLICY,
&mut precedence as *mut _ as *mut _,
THREAD_PRECEDENCE_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread precedence policy");
}
const GUARANTEED_AUDIO_DUTY_CYCLE: f32 = 0.75;
const MAX_AUDIO_DUTY_CYCLE: f32 = 0.85;
// ~128 frames @ 44.1 kHz
const TIME_QUANTUM: f32 = 2.9;
const AUDIO_TIME_NEEDED: f32 = GUARANTEED_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
const MAX_TIME_ALLOWED: f32 = MAX_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
let mut timebase_info = mach_timebase_info_data_t { numer: 0, denom: 0 };
// SAFETY: timebase_info is a valid pointer to a mach_timebase_info_data_t struct
unsafe { mach2::mach_time::mach_timebase_info(&mut timebase_info) };
let ms_to_abs_time = ((timebase_info.denom as f32) / (timebase_info.numer as f32)) * 1000000f32;
let mut time_constraints = thread_time_constraint_policy_data_t {
period: (TIME_QUANTUM * ms_to_abs_time) as u32,
computation: (AUDIO_TIME_NEEDED * ms_to_abs_time) as u32,
constraint: (MAX_TIME_ALLOWED * ms_to_abs_time) as u32,
preemptible: 0,
};
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_time_constraint_policy_data_t is passed as THREAD_TIME_CONSTRAINT_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_TIME_CONSTRAINT_POLICY,
&mut time_constraints as *mut _ as *mut _,
THREAD_TIME_CONSTRAINT_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread time constraint policy");
}
Ok(())
}
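As a sanity check on `ms_to_abs_time` above: `mach_timebase_info` reports nanoseconds per absolute-time tick as the ratio `numer / denom`, so converting a duration in milliseconds to ticks inverts that ratio and scales by the nanoseconds-per-millisecond factor:

```latex
\mathrm{ns} = \mathrm{ticks}\cdot\frac{numer}{denom}
\quad\Longrightarrow\quad
\mathrm{ticks} = \mathrm{ms}\times 10^{6}\times\frac{denom}{numer}
```

which matches `((timebase_info.denom as f32) / (timebase_info.numer as f32)) * 1000000f32` in the code above.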
extern "C" fn trampoline(runnable: *mut c_void) {

View File

@@ -1,4 +1,4 @@
use crate::{PlatformDispatcher, RunnableVariant, TaskLabel};
use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
use backtrace::Backtrace;
use collections::{HashMap, HashSet, VecDeque};
use parking::Unparker;
@@ -284,7 +284,7 @@ impl PlatformDispatcher for TestDispatcher {
state.start_time + state.time
}
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
{
let mut state = self.state.lock();
if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
@@ -296,7 +296,7 @@ impl PlatformDispatcher for TestDispatcher {
self.unpark_all();
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
self.state
.lock()
.foreground
@@ -318,4 +318,10 @@ impl PlatformDispatcher for TestDispatcher {
fn as_test(&self) -> Option<&TestDispatcher> {
Some(self)
}
fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
f();
});
}
}

View File

@@ -3,7 +3,7 @@ use crate::{
DummyKeyboardMapper, ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay,
PlatformKeyboardLayout, PlatformKeyboardMapper, PlatformTextSystem, PromptButton,
ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, SourceMetadata, Task,
TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
TestDisplay, TestPlatformWindow, WindowAppearance, WindowParams, size,
};
use anyhow::Result;
use collections::VecDeque;
@@ -26,7 +26,7 @@ pub(crate) struct TestPlatform {
background_executor: BackgroundExecutor,
foreground_executor: ForegroundExecutor,
pub(crate) active_window: RefCell<Option<TestWindow>>,
pub(crate) active_window: RefCell<Option<TestPlatformWindow>>,
active_display: Rc<dyn PlatformDisplay>,
active_cursor: Mutex<CursorStyle>,
current_clipboard_item: Mutex<Option<ClipboardItem>>,
@@ -196,7 +196,7 @@ impl TestPlatform {
rx
}
pub(crate) fn set_active_window(&self, window: Option<TestWindow>) {
pub(crate) fn set_active_window(&self, window: Option<TestPlatformWindow>) {
let executor = self.foreground_executor();
let previous_window = self.active_window.borrow_mut().take();
self.active_window.borrow_mut().clone_from(&window);
@@ -314,7 +314,7 @@ impl Platform for TestPlatform {
handle: AnyWindowHandle,
params: WindowParams,
) -> anyhow::Result<Box<dyn crate::PlatformWindow>> {
let window = TestWindow::new(
let window = TestPlatformWindow::new(
handle,
params,
self.weak.clone(),

View File

@@ -12,7 +12,7 @@ use std::{
sync::{self, Arc},
};
pub(crate) struct TestWindowState {
pub(crate) struct TestPlatformWindowState {
pub(crate) bounds: Bounds<Pixels>,
pub(crate) handle: AnyWindowHandle,
display: Rc<dyn PlatformDisplay>,
@@ -32,9 +32,9 @@ pub(crate) struct TestWindowState {
}
#[derive(Clone)]
pub(crate) struct TestWindow(pub(crate) Rc<Mutex<TestWindowState>>);
pub(crate) struct TestPlatformWindow(pub(crate) Rc<Mutex<TestPlatformWindowState>>);
impl HasWindowHandle for TestWindow {
impl HasWindowHandle for TestPlatformWindow {
fn window_handle(
&self,
) -> Result<raw_window_handle::WindowHandle<'_>, raw_window_handle::HandleError> {
@@ -42,7 +42,7 @@ impl HasWindowHandle for TestWindow {
}
}
impl HasDisplayHandle for TestWindow {
impl HasDisplayHandle for TestPlatformWindow {
fn display_handle(
&self,
) -> Result<raw_window_handle::DisplayHandle<'_>, raw_window_handle::HandleError> {
@@ -50,14 +50,14 @@ impl HasDisplayHandle for TestWindow {
}
}
impl TestWindow {
impl TestPlatformWindow {
pub fn new(
handle: AnyWindowHandle,
params: WindowParams,
platform: Weak<TestPlatform>,
display: Rc<dyn PlatformDisplay>,
) -> Self {
Self(Rc::new(Mutex::new(TestWindowState {
Self(Rc::new(Mutex::new(TestPlatformWindowState {
bounds: params.bounds,
display,
platform,
@@ -111,7 +111,7 @@ impl TestWindow {
}
}
impl PlatformWindow for TestWindow {
impl PlatformWindow for TestPlatformWindow {
fn bounds(&self) -> Bounds<Pixels> {
self.0.lock().bounds
}
@@ -272,7 +272,7 @@ impl PlatformWindow for TestWindow {
self.0.lock().sprite_atlas.clone()
}
fn as_test(&mut self) -> Option<&mut TestWindow> {
fn as_test(&mut self) -> Option<&mut TestPlatformWindow> {
Some(self)
}

View File

@@ -4,24 +4,31 @@ use std::{
time::{Duration, Instant},
};
use flume::Sender;
use anyhow::Context;
use util::ResultExt;
use windows::{
System::Threading::{ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler},
System::Threading::{
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
},
Win32::{
Foundation::{LPARAM, WPARAM},
System::Threading::{
GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
},
UI::WindowsAndMessaging::PostMessageW,
},
};
use crate::{
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
};
pub(crate) struct WindowsDispatcher {
pub(crate) wake_posted: AtomicBool,
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueSender<RunnableVariant>,
main_thread_id: ThreadId,
pub(crate) platform_window_handle: SafeHwnd,
validation_number: usize,
@@ -29,7 +36,7 @@ pub(crate) struct WindowsDispatcher {
impl WindowsDispatcher {
pub(crate) fn new(
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueSender<RunnableVariant>,
platform_window_handle: HWND,
validation_number: usize,
) -> Self {
@@ -45,7 +52,7 @@ impl WindowsDispatcher {
}
}
fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
let handler = {
let mut task_wrapper = Some(runnable);
WorkItemHandler::new(move |_| {
@@ -53,7 +60,8 @@ impl WindowsDispatcher {
Ok(())
})
};
ThreadPool::RunAsync(&handler).log_err();
ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
}
fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
@@ -79,7 +87,7 @@ impl WindowsDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
@@ -91,7 +99,7 @@ impl WindowsDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
@@ -102,23 +110,7 @@ impl WindowsDispatcher {
let end = Instant::now();
timing.end = Some(end);
Self::add_task_timing(timing);
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
profiler::add_task_timing(timing);
}
}
@@ -146,15 +138,22 @@ impl PlatformDispatcher for WindowsDispatcher {
current().id() == self.main_thread_id
}
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
self.dispatch_on_threadpool(runnable);
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
let priority = match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => WorkItemPriority::High,
Priority::Medium => WorkItemPriority::Normal,
Priority::Low => WorkItemPriority::Low,
};
self.dispatch_on_threadpool(priority, runnable);
if let Some(label) = label {
log::debug!("TaskLabel: {label:?}");
}
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
match self.main_sender.send(runnable) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
match self.main_sender.send(priority, runnable) {
Ok(_) => {
if !self.wake_posted.swap(true, Ordering::AcqRel) {
unsafe {
@@ -185,4 +184,28 @@ impl PlatformDispatcher for WindowsDispatcher {
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
self.dispatch_on_threadpool_after(runnable, duration);
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
// SAFETY: always safe to call
let thread_handle = unsafe { GetCurrentThread() };
let thread_priority = match priority {
RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
};
// SAFETY: thread_handle is a valid handle to a thread
unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
.context("thread priority class")
.log_err();
// SAFETY: thread_handle is a valid handle to a thread
unsafe { SetThreadPriority(thread_handle, thread_priority) }
.context("thread priority")
.log_err();
f();
});
}
}

View File

@@ -243,7 +243,8 @@ impl WindowsWindowInner {
fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> {
if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID {
for runnable in self.main_receiver.drain() {
for runnable in self.main_receiver.clone().try_iter().flatten() {
WindowsDispatcher::execute_runnable(runnable);
}
self.handle_paint_msg(handle)

View File

@@ -51,7 +51,7 @@ struct WindowsPlatformInner {
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
// The below members will never change throughout the entire lifecycle of the app.
validation_number: usize,
main_receiver: flume::Receiver<RunnableVariant>,
main_receiver: PriorityQueueReceiver<RunnableVariant>,
dispatcher: Arc<WindowsDispatcher>,
}
@@ -98,7 +98,7 @@ impl WindowsPlatform {
OleInitialize(None).context("unable to initialize Windows OLE")?;
}
let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?;
let (main_sender, main_receiver) = flume::unbounded::<RunnableVariant>();
let (main_sender, main_receiver) = PriorityQueueReceiver::new();
let validation_number = if usize::BITS == 64 {
rand::random::<u64>() as usize
} else {
@@ -857,22 +857,24 @@ impl WindowsPlatformInner {
}
break 'tasks;
}
match self.main_receiver.try_recv() {
Err(_) => break 'timeout_loop,
Ok(runnable) => WindowsDispatcher::execute_runnable(runnable),
let mut main_receiver = self.main_receiver.clone();
match main_receiver.try_pop() {
Ok(Some(runnable)) => WindowsDispatcher::execute_runnable(runnable),
_ => break 'timeout_loop,
}
}
// Someone could enqueue a Runnable here. The flag is still true, so they will not PostMessage.
// We need to check for those Runnables after we clear the flag.
self.dispatcher.wake_posted.store(false, Ordering::Release);
match self.main_receiver.try_recv() {
Err(_) => break 'tasks,
Ok(runnable) => {
let mut main_receiver = self.main_receiver.clone();
match main_receiver.try_pop() {
Ok(Some(runnable)) => {
self.dispatcher.wake_posted.store(true, Ordering::Release);
WindowsDispatcher::execute_runnable(runnable);
}
_ => break 'tasks,
}
}
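The comment above describes the race the `wake_posted` flag guards against: a producer that enqueues while the flag is still set skips posting a wake-up, so the consumer must re-check the queue after clearing the flag. A stand-alone sketch of that handshake (illustrative names; `wakes_sent` stands in for `PostMessageW`, and the Win32 message loop is reduced to `drain`):

```rust
use std::collections::VecDeque;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

/// Illustrative stand-in for the dispatcher state.
struct Shared {
    queue: Mutex<VecDeque<u32>>,
    wake_posted: AtomicBool,
    wakes_sent: AtomicUsize, // plays the role of PostMessageW
}

fn enqueue(s: &Shared, item: u32) {
    s.queue.lock().unwrap().push_back(item);
    // Only the first producer after a drain posts a wake-up; later
    // producers see the flag already set and skip it.
    if !s.wake_posted.swap(true, Ordering::AcqRel) {
        s.wakes_sent.fetch_add(1, Ordering::Relaxed);
    }
}

fn drain(s: &Shared) -> Vec<u32> {
    let mut out = Vec::new();
    while let Some(x) = s.queue.lock().unwrap().pop_front() {
        out.push(x);
    }
    // Clear the flag, then re-check the queue: a producer that enqueued
    // while the flag was still true skipped its wake-up, so those items
    // must be picked up here or they would be stranded.
    s.wake_posted.store(false, Ordering::Release);
    while let Some(x) = s.queue.lock().unwrap().pop_front() {
        s.wake_posted.store(true, Ordering::Release);
        out.push(x);
    }
    out
}

fn main() {
    let s = Shared {
        queue: Mutex::new(VecDeque::new()),
        wake_posted: AtomicBool::new(false),
        wakes_sent: AtomicUsize::new(0),
    };
    enqueue(&s, 1);
    enqueue(&s, 2); // coalesced: no second wake-up posted
    assert_eq!(s.wakes_sent.load(Ordering::Relaxed), 1);
    assert_eq!(drain(&s), vec![1, 2]);
    enqueue(&s, 3); // flag was cleared, so a new wake-up is posted
    assert_eq!(s.wakes_sent.load(Ordering::Relaxed), 2);
}
```

The flag coalesces wake-ups: any number of enqueues between drains costs at most one posted message.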
@@ -934,7 +936,7 @@ pub(crate) struct WindowCreationInfo {
pub(crate) windows_version: WindowsVersion,
pub(crate) drop_target_helper: IDropTargetHelper,
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
pub(crate) platform_window_handle: HWND,
pub(crate) disable_direct_composition: bool,
pub(crate) directx_devices: DirectXDevices,
@@ -947,8 +949,8 @@ struct PlatformWindowCreateContext {
inner: Option<Result<Rc<WindowsPlatformInner>>>,
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
validation_number: usize,
main_sender: Option<flume::Sender<RunnableVariant>>,
main_receiver: Option<flume::Receiver<RunnableVariant>>,
main_sender: Option<PriorityQueueSender<RunnableVariant>>,
main_receiver: Option<PriorityQueueReceiver<RunnableVariant>>,
directx_devices: Option<DirectXDevices>,
dispatcher: Option<Arc<WindowsDispatcher>>,
}

View File

@@ -81,7 +81,7 @@ pub(crate) struct WindowsWindowInner {
pub(crate) executor: ForegroundExecutor,
pub(crate) windows_version: WindowsVersion,
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
pub(crate) platform_window_handle: HWND,
}
@@ -362,7 +362,7 @@ struct WindowCreateContext {
windows_version: WindowsVersion,
drop_target_helper: IDropTargetHelper,
validation_number: usize,
main_receiver: flume::Receiver<RunnableVariant>,
main_receiver: PriorityQueueReceiver<RunnableVariant>,
platform_window_handle: HWND,
appearance: WindowAppearance,
disable_direct_composition: bool,

View File

@@ -216,3 +216,19 @@ impl Drop for ThreadTimings {
thread_timings.swap_remove(index);
}
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.back_mut() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
}
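The coalescing above (merging consecutive timings from the same source location by extending the previous entry's end) can be exercised in isolation. A minimal sketch, where `Timing` is an illustrative stand-in for `TaskTiming`:

```rust
use std::collections::VecDeque;

/// Illustrative stand-in for `TaskTiming`.
#[derive(Debug, Clone, PartialEq)]
struct Timing {
    location: &'static str,
    start: u64,
    end: Option<u64>,
}

/// Merge consecutive timings from the same location into one entry,
/// keeping the ring buffer compact.
fn push_timing(buf: &mut VecDeque<Timing>, timing: Timing) {
    if let Some(last) = buf.back_mut() {
        if last.location == timing.location {
            last.end = timing.end;
            return;
        }
    }
    buf.push_back(timing);
}

fn main() {
    let mut buf = VecDeque::new();
    push_timing(&mut buf, Timing { location: "a.rs:1", start: 0, end: Some(1) });
    push_timing(&mut buf, Timing { location: "a.rs:1", start: 1, end: Some(2) });
    push_timing(&mut buf, Timing { location: "b.rs:9", start: 2, end: Some(3) });
    // The two "a.rs:1" timings collapse into one entry.
    assert_eq!(buf.len(), 2);
    assert_eq!(buf[0].end, Some(2));
}
```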

329
crates/gpui/src/queue.rs Normal file
View File

@@ -0,0 +1,329 @@
use std::{
fmt,
iter::FusedIterator,
sync::{Arc, atomic::AtomicUsize},
};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use crate::Priority;
struct PriorityQueues<T> {
high_priority: Vec<T>,
medium_priority: Vec<T>,
low_priority: Vec<T>,
}
impl<T> PriorityQueues<T> {
fn is_empty(&self) -> bool {
self.high_priority.is_empty()
&& self.medium_priority.is_empty()
&& self.low_priority.is_empty()
}
}
struct PriorityQueueState<T> {
queues: parking_lot::Mutex<PriorityQueues<T>>,
condvar: parking_lot::Condvar,
receiver_count: AtomicUsize,
sender_count: AtomicUsize,
}
impl<T> PriorityQueueState<T> {
fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
if self
.receiver_count
.load(std::sync::atomic::Ordering::Relaxed)
== 0
{
return Err(SendError(item));
}
let mut queues = self.queues.lock();
match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => queues.high_priority.push(item),
Priority::Medium => queues.medium_priority.push(item),
Priority::Low => queues.low_priority.push(item),
};
self.condvar.notify_one();
Ok(())
}
fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
let mut queues = self.queues.lock();
// parking_lot condvars have no spurious wakeups, but another receiver can
// drain the queue between notify and wake-up, so re-check emptiness (and
// sender disconnection) in a loop rather than a single `if`.
loop {
if !queues.is_empty() {
return Ok(queues);
}
if self.sender_count.load(std::sync::atomic::Ordering::Acquire) == 0 {
return Err(crate::queue::RecvError);
}
self.condvar.wait(&mut queues);
}
}
fn try_recv<'a>(
&'a self,
) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
let mut queues = self.queues.lock();
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
if queues.is_empty() && sender_count == 0 {
return Err(crate::queue::RecvError);
}
if queues.is_empty() {
Ok(None)
} else {
Ok(Some(queues))
}
}
}
pub(crate) struct PriorityQueueSender<T> {
state: Arc<PriorityQueueState<T>>,
}
impl<T> PriorityQueueSender<T> {
fn new(state: Arc<PriorityQueueState<T>>) -> Self {
Self { state }
}
pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
self.state.send(priority, item)?;
Ok(())
}
}
impl<T> Drop for PriorityQueueSender<T> {
fn drop(&mut self) {
self.state
.sender_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
}
}
pub(crate) struct PriorityQueueReceiver<T> {
state: Arc<PriorityQueueState<T>>,
rand: SmallRng,
disconnected: bool,
}
impl<T> Clone for PriorityQueueReceiver<T> {
fn clone(&self) -> Self {
self.state
.receiver_count
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
Self {
state: Arc::clone(&self.state),
rand: SmallRng::seed_from_u64(0),
disconnected: self.disconnected,
}
}
}
pub(crate) struct SendError<T>(T);
impl<T: fmt::Debug> fmt::Debug for SendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SendError").field(&self.0).finish()
}
}
#[derive(Debug)]
pub(crate) struct RecvError;
#[allow(dead_code)]
impl<T> PriorityQueueReceiver<T> {
pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
let state = PriorityQueueState {
queues: parking_lot::Mutex::new(PriorityQueues {
high_priority: Vec::new(),
medium_priority: Vec::new(),
low_priority: Vec::new(),
}),
condvar: parking_lot::Condvar::new(),
receiver_count: AtomicUsize::new(1),
sender_count: AtomicUsize::new(1),
};
let state = Arc::new(state);
let sender = PriorityQueueSender::new(Arc::clone(&state));
let receiver = PriorityQueueReceiver {
state,
rand: SmallRng::seed_from_u64(0),
disconnected: false,
};
(sender, receiver)
}
/// Tries to pop one element from the priority queue without blocking.
///
/// Returns `Ok(None)` immediately if the queue is empty.
///
/// This method is best suited to popping a single element; to drain
/// larger queues more efficiently, see [`Self::try_iter`].
///
/// # Errors
///
/// If the sender was dropped
pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
self.pop_inner(false)
}
/// Pops an element from the priority queue blocking if necessary.
///
/// This method is best suited to popping a single element; to drain
/// larger queues more efficiently, see [`Self::iter`].
///
/// # Errors
///
/// If the sender was dropped
pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
self.pop_inner(true).map(|e| e.unwrap())
}
/// Returns an iterator over the elements of the queue.
/// The iterator ends once all queued elements have been consumed; it does not wait for new ones.
pub(crate) fn try_iter(self) -> TryIter<T> {
TryIter {
receiver: self,
ended: false,
}
}
/// Returns an iterator over the elements of the queue.
/// The iterator blocks waiting for new elements whenever the queue is empty.
pub(crate) fn iter(self) -> Iter<T> {
Iter(self)
}
#[inline(always)]
// algorithm is the loaded die from biased coin from
// https://www.keithschwarz.com/darts-dice-coins/
fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
use Priority as P;
let mut queues = if !block {
let Some(queues) = self.state.try_recv()? else {
return Ok(None);
};
queues
} else {
self.state.recv()?
};
let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
let mut mass = high + medium + low;
if !queues.high_priority.is_empty() {
let flip = self.rand.random_ratio(P::High.probability(), mass);
if flip {
return Ok(queues.high_priority.pop());
}
mass -= P::High.probability();
}
if !queues.medium_priority.is_empty() {
let flip = self.rand.random_ratio(P::Medium.probability(), mass);
if flip {
return Ok(queues.medium_priority.pop());
}
mass -= P::Medium.probability();
}
if !queues.low_priority.is_empty() {
let flip = self.rand.random_ratio(P::Low.probability(), mass);
if flip {
return Ok(queues.low_priority.pop());
}
}
Ok(None)
}
}
impl<T> Drop for PriorityQueueReceiver<T> {
fn drop(&mut self) {
self.state
.receiver_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
}
}
/// Blocking iterator over the queue; yields `None` once all senders have disconnected.
pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
impl<T> Iterator for Iter<T> {
type Item = T;
fn next(&mut self) -> Option<Self::Item> {
self.0.pop_inner(true).ok().flatten()
}
}
impl<T> FusedIterator for Iter<T> {}
/// Non-blocking iterator over the queue; yields `None` once the queue is empty.
pub(crate) struct TryIter<T> {
receiver: PriorityQueueReceiver<T>,
ended: bool,
}
impl<T> Iterator for TryIter<T> {
type Item = Result<T, RecvError>;
fn next(&mut self) -> Option<Self::Item> {
if self.ended {
return None;
}
let res = self.receiver.pop_inner(false);
self.ended = res.is_err();
res.transpose()
}
}
impl<T> FusedIterator for TryIter<T> {}
#[cfg(test)]
mod tests {
use collections::HashSet;
use super::*;
#[test]
fn all_tasks_get_yielded() {
let (tx, mut rx) = PriorityQueueReceiver::new();
tx.send(Priority::Medium, 20).unwrap();
tx.send(Priority::High, 30).unwrap();
tx.send(Priority::Low, 10).unwrap();
tx.send(Priority::Medium, 21).unwrap();
tx.send(Priority::High, 31).unwrap();
drop(tx);
assert_eq!(
rx.iter().collect::<HashSet<_>>(),
[30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
)
}
#[test]
fn new_high_prio_task_get_scheduled_quickly() {
let (tx, mut rx) = PriorityQueueReceiver::new();
for _ in 0..100 {
tx.send(Priority::Low, 1).unwrap();
}
assert_eq!(rx.pop().unwrap(), 1);
tx.send(Priority::High, 3).unwrap();
assert_eq!(rx.pop().unwrap(), 3);
assert_eq!(rx.pop().unwrap(), 1);
}
}
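The weighted selection in `pop_inner` follows the "loaded die from biased coins" construction its comment references: walk the queues in priority order, accepting queue `i` with probability `weight[i] / remaining_mass`, which yields exact weighted sampling without normalizing. A stand-alone sketch under stated assumptions (`Lcg` is an illustrative stand-in for `rand::SmallRng`, and `pick` for the priority-specific branches):

```rust
/// Deterministic LCG standing in for `rand::SmallRng`; any uniform
/// u32 source works here (constants are Knuth's MMIX parameters).
struct Lcg(u64);

impl Lcg {
    fn next_u32(&mut self) -> u32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (self.0 >> 33) as u32
    }

    /// Returns true with probability `num / den` (requires `den > 0`).
    fn ratio(&mut self, num: u32, den: u32) -> bool {
        (self.next_u32() % den) < num
    }
}

/// Walk the weights in priority order, accepting entry `i` with
/// probability `weight[i] / remaining_mass`; zero-weight (empty)
/// queues are skipped. Returns `None` only if every weight is zero.
fn pick(rng: &mut Lcg, weights: &[u32]) -> Option<usize> {
    let mut mass: u32 = weights.iter().sum();
    for (i, &w) in weights.iter().enumerate() {
        if w == 0 {
            continue;
        }
        // For the last non-empty queue, mass == w, so this always accepts.
        if rng.ratio(w, mass) {
            return Some(i);
        }
        mass -= w;
    }
    None
}

fn main() {
    let mut rng = Lcg(42);
    let mut counts = [0u32; 3];
    for _ in 0..10_000 {
        if let Some(i) = pick(&mut rng, &[60, 30, 10]) {
            counts[i] += 1;
        }
    }
    // Higher-priority queues are drawn proportionally more often.
    assert!(counts[0] > counts[1] && counts[1] > counts[2]);
    assert_eq!(pick(&mut rng, &[0, 0, 0]), None);
}
```

Because the last non-empty queue is always accepted, `pick` never starves: low-priority work is merely drawn less often, not skipped.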

View File

@@ -20,6 +20,110 @@ pub(crate) type PathVertex_ScaledPixels = PathVertex<ScaledPixels>;
pub(crate) type DrawOrder = u32;
/// Test-only scene snapshot for inspecting rendered content.
#[cfg(any(test, feature = "test-support"))]
pub mod test_scene {
use crate::{Bounds, Hsla, Point, ScaledPixels};
/// A rendered quad (background, border, cursor, selection, etc.)
#[derive(Debug, Clone)]
pub struct RenderedQuad {
/// Bounds in scaled pixels.
pub bounds: Bounds<ScaledPixels>,
/// Background color (if solid).
pub background_color: Option<Hsla>,
/// Border color.
pub border_color: Hsla,
}
/// A rendered text glyph.
#[derive(Debug, Clone)]
pub struct RenderedGlyph {
/// Origin position in scaled pixels.
pub origin: Point<ScaledPixels>,
/// Size in scaled pixels.
pub size: crate::Size<ScaledPixels>,
/// Color of the glyph.
pub color: Hsla,
}
/// Snapshot of scene contents for testing.
#[derive(Debug, Default)]
pub struct SceneSnapshot {
/// All rendered quads.
pub quads: Vec<RenderedQuad>,
/// All rendered text glyphs.
pub glyphs: Vec<RenderedGlyph>,
/// Number of shadow primitives.
pub shadow_count: usize,
/// Number of path primitives.
pub path_count: usize,
/// Number of underline primitives.
pub underline_count: usize,
/// Number of polychrome sprites (images, emoji).
pub polychrome_sprite_count: usize,
/// Number of surface primitives.
pub surface_count: usize,
}
impl SceneSnapshot {
/// Get unique Y positions of quads, sorted.
pub fn quad_y_positions(&self) -> Vec<f32> {
let mut positions: Vec<f32> = self.quads.iter().map(|q| q.bounds.origin.y.0).collect();
positions.sort_by(|a, b| a.total_cmp(b));
positions.dedup();
positions
}
/// Get unique Y positions of glyphs, sorted.
pub fn glyph_y_positions(&self) -> Vec<f32> {
let mut positions: Vec<f32> = self.glyphs.iter().map(|g| g.origin.y.0).collect();
positions.sort_by(|a, b| a.total_cmp(b));
positions.dedup();
positions
}
/// Find quads within a Y range.
pub fn quads_in_y_range(&self, min_y: f32, max_y: f32) -> Vec<&RenderedQuad> {
self.quads
.iter()
.filter(|q| {
let y = q.bounds.origin.y.0;
y >= min_y && y < max_y
})
.collect()
}
/// Find glyphs within a Y range.
pub fn glyphs_in_y_range(&self, min_y: f32, max_y: f32) -> Vec<&RenderedGlyph> {
self.glyphs
.iter()
.filter(|g| {
let y = g.origin.y.0;
y >= min_y && y < max_y
})
.collect()
}
/// Debug summary string.
pub fn summary(&self) -> String {
format!(
"quads: {}, glyphs: {}, shadows: {}, paths: {}, underlines: {}, polychrome: {}, surfaces: {}",
self.quads.len(),
self.glyphs.len(),
self.shadow_count,
self.path_count,
self.underline_count,
self.polychrome_sprite_count,
self.surface_count,
)
}
}
}
#[cfg(any(test, feature = "test-support"))]
pub use test_scene::*;
#[derive(Default)]
pub(crate) struct Scene {
pub(crate) paint_operations: Vec<PaintOperation>,
@@ -124,6 +228,40 @@ impl Scene {
}
}
/// Create a snapshot of the scene for testing.
#[cfg(any(test, feature = "test-support"))]
pub fn snapshot(&self) -> SceneSnapshot {
let quads = self
.quads
.iter()
.map(|q| RenderedQuad {
bounds: q.bounds,
background_color: q.background.as_solid(),
border_color: q.border_color,
})
.collect();
let glyphs = self
.monochrome_sprites
.iter()
.map(|s| RenderedGlyph {
origin: s.bounds.origin,
size: s.bounds.size,
color: s.color,
})
.collect();
SceneSnapshot {
quads,
glyphs,
shadow_count: self.shadows.len(),
path_count: self.paths.len(),
underline_count: self.underlines.len(),
polychrome_sprite_count: self.polychrome_sprites.len(),
surface_count: self.surfaces.len(),
}
}
pub fn finish(&mut self) {
self.shadows.sort_by_key(|shadow| shadow.order);
self.quads.sort_by_key(|quad| quad.order);
@@ -620,7 +758,7 @@ impl Default for TransformationMatrix {
#[repr(C)]
pub(crate) struct MonochromeSprite {
pub order: DrawOrder,
pub pad: u32, // align to 8 bytes
pub pad: u32,
pub bounds: Bounds<ScaledPixels>,
pub content_mask: ContentMask<ScaledPixels>,
pub color: Hsla,
@@ -638,7 +776,7 @@ impl From<MonochromeSprite> for Primitive {
#[repr(C)]
pub(crate) struct PolychromeSprite {
pub order: DrawOrder,
pub pad: u32, // align to 8 bytes
pub pad: u32,
pub grayscale: bool,
pub opacity: f32,
pub bounds: Bounds<ScaledPixels>,

View File

@@ -9,14 +9,15 @@ use crate::{
KeyBinding, KeyContext, KeyDownEvent, KeyEvent, Keystroke, KeystrokeEvent, LayoutId,
LineLayoutIndex, Modifiers, ModifiersChangedEvent, MonochromeSprite, MouseButton, MouseEvent,
MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas, PlatformDisplay, PlatformInput,
PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, PromptButton, PromptLevel, Quad,
Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Replay, ResizeEdge,
SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS_X, SUBPIXEL_VARIANTS_Y, ScaledPixels, Scene, Shadow,
SharedString, Size, StrikethroughStyle, Style, SubscriberSet, Subscription, SystemWindowTab,
SystemWindowTabController, TabStopMap, TaffyLayoutEngine, Task, TextStyle, TextStyleRefinement,
TransformationMatrix, Underline, UnderlineStyle, WindowAppearance, WindowBackgroundAppearance,
WindowBounds, WindowControls, WindowDecorations, WindowOptions, WindowParams, WindowTextSystem,
point, prelude::*, px, rems, size, transparent_black,
PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, Priority, PromptButton,
PromptLevel, Quad, Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams,
Replay, ResizeEdge, SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS_X, SUBPIXEL_VARIANTS_Y,
ScaledPixels, Scene, Shadow, SharedString, Size, StrikethroughStyle, Style, SubscriberSet,
Subscription, SystemWindowTab, SystemWindowTabController, TabStopMap, TaffyLayoutEngine, Task,
TextStyle, TextStyleRefinement, TransformationMatrix, Underline, UnderlineStyle,
WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations,
WindowOptions, WindowParams, WindowTextSystem, point, prelude::*, px, rems, size,
transparent_black,
};
use anyhow::{Context as _, Result, anyhow};
use collections::{FxHashMap, FxHashSet};
@@ -1725,6 +1726,27 @@ impl Window {
})
}
/// Spawn the future returned by the given closure on the application thread
/// pool, with the given priority. The closure is provided an
/// `AsyncWindowContext` scoped to the current window for use within your future.
#[track_caller]
pub fn spawn_with_priority<AsyncFn, R>(
&self,
priority: Priority,
cx: &App,
f: AsyncFn,
) -> Task<R>
where
R: 'static,
AsyncFn: AsyncFnOnce(&mut AsyncWindowContext) -> R + 'static,
{
let handle = self.handle;
cx.spawn_with_priority(priority, async move |app| {
let mut async_window_cx = AsyncWindowContext::new_context(app.clone(), handle);
f(&mut async_window_cx).await
})
}
fn bounds_changed(&mut self, cx: &mut App) {
self.scale_factor = self.platform_window.scale_factor();
self.viewport_size = self.platform_window.content_size();

View File

@@ -18,6 +18,7 @@ test-support = []
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
credentials_provider.workspace = true
base64.workspace = true
client.workspace = true
cloud_api_types.workspace = true
@@ -41,6 +42,7 @@ smol.workspace = true
telemetry_events.workspace = true
thiserror.workspace = true
util.workspace = true
zed_env_vars.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -2,7 +2,6 @@ use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, future};
use gpui::{AsyncApp, Context, SharedString, Task};
use language_model::AuthenticateError;
use std::{
fmt::{Display, Formatter},
sync::Arc,
@@ -10,13 +9,16 @@ use std::{
use util::ResultExt as _;
use zed_env_vars::EnvVar;
use crate::AuthenticateError;
/// Manages a single API key for a language model provider. API keys either come from environment
/// variables or the system keychain.
///
/// Keys from the system keychain are associated with a provider URL, which ensures that they
/// are only used with that URL.
pub struct ApiKeyState {
url: SharedString,
pub url: SharedString,
env_var: EnvVar,
load_status: LoadStatus,
load_task: Option<future::Shared<Task<()>>>,
}
@@ -35,9 +37,10 @@ pub struct ApiKey {
}
impl ApiKeyState {
pub fn new(url: SharedString) -> Self {
pub fn new(url: SharedString, env_var: EnvVar) -> Self {
Self {
url,
env_var,
load_status: LoadStatus::NotPresent,
load_task: None,
}
@@ -47,6 +50,10 @@ impl ApiKeyState {
matches!(self.load_status, LoadStatus::Loaded { .. })
}
pub fn env_var_name(&self) -> &SharedString {
&self.env_var.name
}
pub fn is_from_env_var(&self) -> bool {
match &self.load_status {
LoadStatus::Loaded(ApiKey {
@@ -136,14 +143,13 @@ impl ApiKeyState {
pub fn handle_url_change<Ent: 'static>(
&mut self,
url: SharedString,
env_var: &EnvVar,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
cx: &mut Context<Ent>,
) {
if url != self.url {
if !self.is_from_env_var() {
// loading will continue even though this result task is dropped
let _task = self.load_if_needed(url, env_var, get_this, cx);
let _task = self.load_if_needed(url, get_this, cx);
}
}
}
@@ -156,7 +162,6 @@ impl ApiKeyState {
pub fn load_if_needed<Ent: 'static>(
&mut self,
url: SharedString,
env_var: &EnvVar,
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
cx: &mut Context<Ent>,
) -> Task<Result<(), AuthenticateError>> {
@@ -166,10 +171,10 @@ impl ApiKeyState {
return Task::ready(Ok(()));
}
if let Some(key) = &env_var.value
if let Some(key) = &self.env_var.value
&& !key.is_empty()
{
let api_key = ApiKey::from_env(env_var.name.clone(), key);
let api_key = ApiKey::from_env(self.env_var.name.clone(), key);
self.url = url;
self.load_status = LoadStatus::Loaded(api_key);
self.load_task = None;
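The hunk above moves the `EnvVar` into `ApiKeyState` itself, so callers no longer thread it through `load_if_needed` and `handle_url_change`. The resolution order can be sketched with simplified stand-ins (these are not the real `zed_env_vars`/`language_model` types): a non-empty environment variable wins regardless of URL, while keychain keys are URL-scoped.

```rust
// Simplified stand-ins for Zed's EnvVar and ApiKeyState.
#[derive(Clone)]
struct EnvVar {
    name: String,
    value: Option<String>,
}

struct ApiKeyState {
    env_var: EnvVar,
    keychain_key: Option<(String, String)>, // (url, key)
}

impl ApiKeyState {
    fn key(&self, url: &str) -> Option<String> {
        if let Some(key) = &self.env_var.value {
            if !key.is_empty() {
                // Env-var keys apply to any URL.
                return Some(key.clone());
            }
        }
        match &self.keychain_key {
            // Keychain keys are only used with the URL they were stored for.
            Some((stored_url, key)) if stored_url == url => Some(key.clone()),
            _ => None,
        }
    }
}

fn main() {
    let state = ApiKeyState {
        env_var: EnvVar { name: "ANTHROPIC_API_KEY".into(), value: None },
        keychain_key: Some(("https://api.anthropic.com".into(), "sk-123".into())),
    };
    assert_eq!(state.key("https://api.anthropic.com").as_deref(), Some("sk-123"));
    // A key stored for one URL is not offered for another.
    assert_eq!(state.key("https://example.com"), None);
}
```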

View File

@@ -1,3 +1,4 @@
mod api_key;
mod model;
mod rate_limiter;
mod registry;
@@ -30,6 +31,7 @@ use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
pub use crate::api_key::{ApiKey, ApiKeyState};
pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
@@ -37,6 +39,7 @@ pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use zed_env_vars::{EnvVar, env_var};
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
@@ -609,6 +612,11 @@ pub trait LanguageModel: Send + Sync {
false
}
/// Returns whether this model or provider supports streaming tool calls.
fn supports_streaming_tools(&self) -> bool {
false
}
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
LanguageModelToolSchemaFormat::JsonSchema
}
@@ -763,6 +771,21 @@ pub trait LanguageModelExt: LanguageModel {
}
impl LanguageModelExt for dyn LanguageModel {}
impl std::fmt::Debug for dyn LanguageModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("<dyn LanguageModel>")
.field("id", &self.id())
.field("name", &self.name())
.field("provider_id", &self.provider_id())
.field("provider_name", &self.provider_name())
.field("upstream_provider_name", &self.upstream_provider_name())
.field("upstream_provider_id", &self.upstream_provider_id())
.field("supports_streaming_tools", &self.supports_streaming_tools())
.finish()
}
}
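The new `Debug` impl for `dyn LanguageModel` formats accessor results rather than struct fields, which is the usual way to get debug output for a trait object. A minimal standalone sketch of the same pattern (trait and type names here are illustrative, not the real ones):

```rust
use std::fmt;

trait Model {
    fn id(&self) -> &str;
    // Default-false capability flag; providers opt in by overriding.
    fn supports_streaming_tools(&self) -> bool {
        false
    }
}

// Debug for the trait object itself, built from accessor calls.
impl fmt::Debug for dyn Model {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("<dyn Model>")
            .field("id", &self.id())
            .field("supports_streaming_tools", &self.supports_streaming_tools())
            .finish()
    }
}

struct StreamingModel;
impl Model for StreamingModel {
    fn id(&self) -> &str {
        "streaming-model"
    }
    fn supports_streaming_tools(&self) -> bool {
        true
    }
}

fn main() {
    let m: &dyn Model = &StreamingModel;
    println!("{m:?}");
}
```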
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {

View File

@@ -60,7 +60,6 @@ ui_input.workspace = true
util.workspace = true
vercel = { workspace = true, features = ["schemars"] }
x_ai = { workspace = true, features = ["schemars"] }
zed_env_vars.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -7,10 +7,8 @@ use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
mod api_key;
pub mod provider;
mod settings;
pub mod ui;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;

View File

@@ -8,25 +8,21 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::B
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
ApiKeyState, AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModel,
LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, env_var,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};
use crate::api_key::ApiKeyState;
use crate::ui::{ConfiguredApiCard, InstructionListItem};
pub use settings::AnthropicAvailableModel as AvailableModel;
@@ -65,12 +61,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = AnthropicLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
@@ -79,17 +71,13 @@ impl AnthropicLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
});
@@ -362,6 +350,10 @@ impl LanguageModel for AnthropicModel {
true
}
fn supports_streaming_tools(&self) -> bool {
true
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
@@ -937,14 +929,12 @@ impl Render for ConfigurationView {
.child(
List::new()
.child(
InstructionListItem::new(
"Create one by visiting",
Some("Anthropic's settings"),
Some("https://console.anthropic.com/settings/keys")
)
ListBulletItem::new("")
.child(Label::new("Create one by visiting"))
.child(ButtonLink::new("Anthropic's settings", "https://console.anthropic.com/settings/keys"))
)
.child(
InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
)
)
.child(self.api_key_editor.clone())
@@ -953,7 +943,8 @@ impl Render for ConfigurationView {
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small)
.color(Color::Muted),
.color(Color::Muted)
.mt_0p5(),
)
.into_any_element()
} else {

View File

@@ -2,7 +2,6 @@ use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use crate::ui::{ConfiguredApiCard, InstructionListItem};
use anyhow::{Context as _, Result, anyhow};
use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
use aws_config::{BehaviorVersion, Region};
@@ -44,7 +43,7 @@ use serde_json::Value;
use settings::{BedrockAvailableModel as AvailableModel, Settings, SettingsStore};
use smol::lock::OnceCell;
use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
@@ -1250,18 +1249,14 @@ impl Render for ConfigurationView {
.child(
List::new()
.child(
InstructionListItem::new(
"Grant permissions to the strategy you'll use according to the:",
Some("Prerequisites"),
Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
)
ListBulletItem::new("")
.child(Label::new("Grant permissions to the strategy you'll use according to the:"))
.child(ButtonLink::new("Prerequisites", "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"))
)
.child(
InstructionListItem::new(
"Select the models you would like access to:",
Some("Bedrock Model Catalog"),
Some("https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"),
)
ListBulletItem::new("")
.child(Label::new("Select the models you would like access to:"))
.child(ButtonLink::new("Bedrock Model Catalog", "https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess"))
)
)
.child(self.render_static_credentials_ui())
@@ -1302,22 +1297,22 @@ impl ConfigurationView {
)
.child(
List::new()
.child(InstructionListItem::new(
"Create an IAM user in the AWS console with programmatic access",
Some("IAM Console"),
Some("https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"),
))
.child(InstructionListItem::new(
"Attach the necessary Bedrock permissions to this ",
Some("user"),
Some("https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"),
))
.child(InstructionListItem::text_only(
"Copy the access key ID and secret access key when provided",
))
.child(InstructionListItem::text_only(
"Enter these credentials below",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create an IAM user in the AWS console with programmatic access"))
.child(ButtonLink::new("IAM Console", "https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users"))
)
.child(
ListBulletItem::new("")
.child(Label::new("Attach the necessary Bedrock permissions to this"))
.child(ButtonLink::new("user", "https://docs.aws.amazon.com/bedrock/latest/userguide/inference-prereq.html"))
)
.child(
ListBulletItem::new("Copy the access key ID and secret access key when provided")
)
.child(
ListBulletItem::new("Enter these credentials below")
)
)
.child(self.access_key_id_editor.clone())
.child(self.secret_access_key_editor.clone())

View File

@@ -602,6 +602,10 @@ impl LanguageModel for CloudLanguageModel {
self.model.supports_images
}
fn supports_streaming_tools(&self) -> bool {
self.model.supports_streaming_tools
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto

View File

@@ -14,7 +14,7 @@ use copilot::{Copilot, Status};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, Stream, StreamExt};
use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg};
use gpui::{AnyView, App, AsyncApp, Entity, Subscription, Task};
use http_client::StatusCode;
use language::language_settings::all_language_settings;
use language_model::{
@@ -26,11 +26,9 @@ use language_model::{
StopReason, TokenUsage,
};
use settings::SettingsStore;
use ui::{CommonAnimationExt, prelude::*};
use ui::prelude::*;
use util::debug_panic;
use crate::ui::ConfiguredApiCard;
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("GitHub Copilot Chat");
@@ -179,8 +177,18 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
_: &mut Window,
cx: &mut App,
) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, cx)).into()
cx.new(|cx| {
copilot::ConfigurationView::new(
|cx| {
CopilotChat::global(cx)
.map(|m| m.read(cx).is_authenticated())
.unwrap_or(false)
},
copilot::ConfigurationMode::Chat,
cx,
)
})
.into()
}
fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
@@ -1474,92 +1482,3 @@ mod tests {
);
}
}
struct ConfigurationView {
copilot_status: Option<copilot::Status>,
state: Entity<State>,
_subscription: Option<Subscription>,
}
impl ConfigurationView {
pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
let copilot = Copilot::global(cx);
Self {
copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
state,
_subscription: copilot.as_ref().map(|copilot| {
cx.observe(copilot, |this, model, cx| {
this.copilot_status = Some(model.read(cx).status());
cx.notify();
})
}),
}
}
}
impl Render for ConfigurationView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
if self.state.read(cx).is_authenticated(cx) {
ConfiguredApiCard::new("Authorized")
.button_label("Sign Out")
.on_click(|_, window, cx| {
window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
})
.into_any_element()
} else {
let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4);
const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
match &self.copilot_status {
Some(status) => match status {
Status::Starting { task: _ } => h_flex()
.gap_2()
.child(loading_icon)
.child(Label::new("Starting Copilot…"))
.into_any_element(),
Status::SigningIn { prompt: _ }
| Status::SignedOut {
awaiting_signing_in: true,
} => h_flex()
.gap_2()
.child(loading_icon)
.child(Label::new("Signing into Copilot…"))
.into_any_element(),
Status::Error(_) => {
const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
v_flex()
.gap_6()
.child(Label::new(LABEL))
.child(svg().size_8().path(IconName::CopilotError.path()))
.into_any_element()
}
_ => {
const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
v_flex()
.gap_2()
.child(Label::new(LABEL))
.child(
Button::new("sign_in", "Sign in to use GitHub Copilot")
.full_width()
.style(ButtonStyle::Outlined)
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.on_click(|_, window, cx| {
copilot::initiate_sign_in(window, cx)
}),
)
.into_any_element()
}
},
None => v_flex()
.gap_6()
.child(Label::new(ERROR_LABEL))
.into_any_element(),
}
}
}
}

View File

@@ -7,11 +7,11 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
};
pub use settings::DeepseekAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
@@ -19,13 +19,9 @@ use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};
use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
@@ -67,12 +63,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
@@ -81,17 +73,13 @@ impl DeepSeekLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
});
@@ -632,12 +620,15 @@ impl Render for ConfigurationView {
.child(Label::new("To use DeepSeek in Zed, you need an API key:"))
.child(
List::new()
.child(InstructionListItem::new(
"Get your API key from the",
Some("DeepSeek console"),
Some("https://platform.deepseek.com/api_keys"),
))
.child(InstructionListItem::text_only(
.child(
ListBulletItem::new("")
.child(Label::new("Get your API key from the"))
.child(ButtonLink::new(
"DeepSeek console",
"https://platform.deepseek.com/api_keys",
)),
)
.child(ListBulletItem::new(
"Paste your API key below and hit enter to start using the assistant",
)),
)

View File

@@ -9,7 +9,7 @@ use google_ai::{
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
@@ -28,14 +28,11 @@ use std::sync::{
atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::EnvVar;
use crate::api_key::ApiKey;
use crate::api_key::ApiKeyState;
use crate::ui::{ConfiguredApiCard, InstructionListItem};
use language_model::{ApiKey, ApiKeyState};
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
@@ -87,12 +84,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = GoogleLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
@@ -101,17 +94,13 @@ impl GoogleLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
});
@@ -873,14 +862,14 @@ impl Render for ConfigurationView {
})))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("Google AI's console"),
Some("https://aistudio.google.com/app/apikey"),
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the assistant",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create one by visiting"))
.child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
)
.child(
ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
)
)
.child(self.api_key_editor.clone())
.child(

View File

@@ -20,11 +20,10 @@ use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use ui::{ButtonLike, Indicator, InlineCode, List, ListBulletItem, prelude::*};
use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
@@ -686,12 +685,14 @@ impl Render for ConfigurationView {
.child(
v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
List::new()
.child(InstructionListItem::text_only(
.child(ListBulletItem::new(
"LM Studio needs to be running with at least one model downloaded.",
))
.child(InstructionListItem::text_only(
"To get your first model, try running `lms get qwen2.5-coder-7b`",
)),
.child(
ListBulletItem::new("")
.child(Label::new("To get your first model, try running"))
.child(InlineCode::new("lms get qwen2.5-coder-7b")),
),
),
)
.child(

View File

@@ -1,31 +1,27 @@
use anyhow::{Result, anyhow};
use collections::BTreeMap;
use fs::Fs;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
};
use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
pub use mistral::{CODESTRAL_API_URL, MISTRAL_API_URL, StreamResponse};
pub use settings::MistralAvailableModel as AvailableModel;
use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
use settings::{Settings, SettingsStore};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use std::sync::{Arc, LazyLock, OnceLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};
use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
@@ -35,6 +31,7 @@ static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
const CODESTRAL_API_KEY_ENV_VAR_NAME: &str = "CODESTRAL_API_KEY";
static CODESTRAL_API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(CODESTRAL_API_KEY_ENV_VAR_NAME);
static CODESTRAL_API_KEY: OnceLock<Entity<ApiKeyState>> = OnceLock::new();
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
@@ -44,12 +41,22 @@ pub struct MistralSettings {
pub struct MistralLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: Entity<State>,
pub state: Entity<State>,
}
pub struct State {
api_key_state: ApiKeyState,
codestral_api_key_state: ApiKeyState,
codestral_api_key_state: Entity<ApiKeyState>,
}
pub fn codestral_api_key(cx: &mut App) -> Entity<ApiKeyState> {
CODESTRAL_API_KEY
.get_or_init(|| {
cx.new(|_| {
ApiKeyState::new(CODESTRAL_API_URL.into(), CODESTRAL_API_KEY_ENV_VAR.clone())
})
})
.clone()
}
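The `codestral_api_key` accessor above uses a `OnceLock` to hold a single process-wide handle: the first call builds the state, and every later call clones the same handle. The pattern in isolation, with a plain `String` standing in for the `Entity<ApiKeyState>` handle:

```rust
use std::sync::OnceLock;

// Process-wide slot, initialized lazily on first access.
static SHARED: OnceLock<String> = OnceLock::new();

fn shared_key_state() -> String {
    SHARED
        .get_or_init(|| {
            // Runs at most once, even across threads.
            "codestral-key-state".to_string()
        })
        .clone()
}

fn main() {
    let a = shared_key_state(); // triggers initialization
    let b = shared_key_state(); // reuses the cached value
    assert_eq!(a, b);
}
```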
impl State {
@@ -63,39 +70,19 @@ impl State {
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
fn set_codestral_api_key(
&mut self,
api_key: Option<String>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.codestral_api_key_state.store(
CODESTRAL_API_URL.into(),
api_key,
|this| &mut this.codestral_api_key_state,
cx,
)
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = MistralLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
fn authenticate_codestral(
&mut self,
cx: &mut Context<Self>,
) -> Task<Result<(), AuthenticateError>> {
self.codestral_api_key_state.load_if_needed(
CODESTRAL_API_URL.into(),
&CODESTRAL_API_KEY_ENV_VAR,
|this| &mut this.codestral_api_key_state,
cx,
)
self.codestral_api_key_state.update(cx, |state, cx| {
state.load_if_needed(CODESTRAL_API_URL.into(), |state| state, cx)
})
}
}
@@ -116,18 +103,14 @@ impl MistralLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
codestral_api_key_state: ApiKeyState::new(CODESTRAL_API_URL.into()),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
codestral_api_key_state: codestral_api_key(cx),
}
});
@@ -142,7 +125,11 @@ impl MistralLanguageModelProvider {
}
pub fn codestral_api_key(&self, url: &str, cx: &App) -> Option<Arc<str>> {
self.state.read(cx).codestral_api_key_state.key(url)
self.state
.read(cx)
.codestral_api_key_state
.read(cx)
.key(url)
}
fn create_language_model(&self, model: mistral::Model) -> Arc<dyn LanguageModel> {
@@ -159,7 +146,7 @@ impl MistralLanguageModelProvider {
&crate::AllLanguageModelSettings::get_global(cx).mistral
}
fn api_url(cx: &App) -> SharedString {
pub fn api_url(cx: &App) -> SharedString {
let api_url = &Self::settings(cx).api_url;
if api_url.is_empty() {
mistral::MISTRAL_API_URL.into()
@@ -747,7 +734,6 @@ struct RawToolCall {
struct ConfigurationView {
api_key_editor: Entity<InputField>,
codestral_api_key_editor: Entity<InputField>,
state: Entity<State>,
load_credentials_task: Option<Task<()>>,
}
@@ -756,8 +742,6 @@ impl ConfigurationView {
fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_editor =
cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
let codestral_api_key_editor =
cx.new(|cx| InputField::new(window, cx, "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
cx.observe(&state, |_, _, cx| {
cx.notify();
@@ -774,12 +758,6 @@ impl ConfigurationView {
// We don't log an error, because "not signed in" is also an error.
let _ = task.await;
}
if let Some(task) = state
.update(cx, |state, cx| state.authenticate_codestral(cx))
.log_err()
{
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
@@ -791,7 +769,6 @@ impl ConfigurationView {
Self {
api_key_editor,
codestral_api_key_editor,
state,
load_credentials_task,
}
@@ -829,110 +806,9 @@ impl ConfigurationView {
.detach_and_log_err(cx);
}
fn save_codestral_api_key(
&mut self,
_: &menu::Confirm,
window: &mut Window,
cx: &mut Context<Self>,
) {
let api_key = self
.codestral_api_key_editor
.read(cx)
.text(cx)
.trim()
.to_string();
if api_key.is_empty() {
return;
}
// url changes can cause the editor to be displayed again
self.codestral_api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| {
state.set_codestral_api_key(Some(api_key), cx)
})?
.await?;
cx.update(|_window, cx| {
set_edit_prediction_provider(EditPredictionProvider::Codestral, cx)
})
})
.detach_and_log_err(cx);
}
fn reset_codestral_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.codestral_api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_codestral_api_key(None, cx))?
.await?;
cx.update(|_window, cx| set_edit_prediction_provider(EditPredictionProvider::Zed, cx))
})
.detach_and_log_err(cx);
}
fn should_render_api_key_editor(&self, cx: &mut Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
fn render_codestral_api_key_editor(&mut self, cx: &mut Context<Self>) -> AnyElement {
let key_state = &self.state.read(cx).codestral_api_key_state;
let should_show_editor = !key_state.has_key();
let env_var_set = key_state.is_from_env_var();
let configured_card_label = if env_var_set {
format!("API key set in {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable")
} else {
"Codestral API key configured".to_string()
};
if should_show_editor {
v_flex()
.id("codestral")
.size_full()
.mt_2()
.on_action(cx.listener(Self::save_codestral_api_key))
.child(Label::new(
"To use Codestral as an edit prediction provider, \
you need to add a Codestral-specific API key. Follow these steps:",
))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("the Codestral section of Mistral's console"),
Some("https://console.mistral.ai/codestral"),
))
.child(InstructionListItem::text_only("Paste your API key below and hit enter")),
)
.child(self.codestral_api_key_editor.clone())
.child(
Label::new(
format!("You can also assign the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
).into_any()
} else {
ConfiguredApiCard::new(configured_card_label)
.disabled(env_var_set)
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
.when(env_var_set, |this| {
this.tooltip_label(format!(
"To reset your API key, \
unset the {CODESTRAL_API_KEY_ENV_VAR_NAME} environment variable."
))
})
.on_click(
cx.listener(|this, _, window, cx| this.reset_codestral_api_key(window, cx)),
)
.into_any_element()
}
}
}
impl Render for ConfigurationView {
@@ -958,17 +834,17 @@ impl Render for ConfigurationView {
.child(Label::new("To use Zed's agent with Mistral, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("Mistral's console"),
Some("https://console.mistral.ai/api-keys"),
))
.child(InstructionListItem::text_only(
"Ensure your Mistral account has credits",
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the assistant",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create one by visiting"))
.child(ButtonLink::new("Mistral's console", "https://console.mistral.ai/api-keys"))
)
.child(
ListBulletItem::new("Ensure your Mistral account has credits")
)
.child(
ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
),
)
.child(self.api_key_editor.clone())
.child(
@@ -977,7 +853,6 @@ impl Render for ConfigurationView {
)
.size(LabelSize::Small).color(Color::Muted),
)
.child(self.render_codestral_api_key_editor(cx))
.into_any()
} else {
v_flex()
@@ -994,24 +869,11 @@ impl Render for ConfigurationView {
))
}),
)
.child(self.render_codestral_api_key_editor(cx))
.into_any()
}
}
}
fn set_edit_prediction_provider(provider: EditPredictionProvider, cx: &mut App) {
let fs = <dyn Fs>::global(cx);
update_settings_file(fs, cx, move |settings, _| {
settings
.project
.all_languages
.features
.get_or_insert_default()
.edit_prediction_provider = Some(provider);
});
}
#[cfg(test)]
mod tests {
use super::*;


@@ -5,11 +5,11 @@ use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
};
use menu;
use ollama::{
@@ -22,13 +22,13 @@ use std::pin::Pin;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::HashMap, sync::Arc};
use ui::{ButtonLike, ElevationIndex, List, Tooltip, prelude::*};
use ui::{
ButtonLike, ButtonLink, ConfiguredApiCard, ElevationIndex, InlineCode, List, ListBulletItem,
Tooltip, prelude::*,
};
use ui_input::InputField;
use zed_env_vars::{EnvVar, env_var};
use crate::AllLanguageModelSettings;
use crate::api_key::ApiKeyState;
use crate::ui::{ConfiguredApiCard, InstructionListItem};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
@@ -80,12 +80,9 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = OllamaLanguageModelProvider::api_url(cx);
let task = self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
let task = self
.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
// Always try to fetch models - if no API key is needed (local Ollama), it will work
// If API key is needed and provided, it will work
@@ -185,7 +182,7 @@ impl OllamaLanguageModelProvider {
http_client,
fetched_models: Default::default(),
fetch_model_task: None,
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
}),
};
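The hunk above moves the env var into the `ApiKeyState` constructor. A minimal sketch of that refactor, using simplified stand-in types (not the real zed/gpui signatures): once the `EnvVar` is bound at construction, call sites like `load_if_needed` no longer pass `&API_KEY_ENV_VAR` at every invocation.

```rust
// Illustrative stand-ins for the refactor: the env var is bound once in
// ApiKeyState::new, so later calls no longer pass it per call site.
#[derive(Clone)]
pub struct EnvVar {
    pub name: String,
    pub value: Option<String>,
}

pub struct ApiKeyState {
    api_url: String,
    env_var: EnvVar, // bound at construction instead of passed per call
    key: Option<String>,
}

impl ApiKeyState {
    // Before: ApiKeyState::new(url) + load_if_needed(url, &ENV_VAR, ..)
    // After:  ApiKeyState::new(url, env_var) + load_if_needed(url, ..)
    pub fn new(api_url: impl Into<String>, env_var: EnvVar) -> Self {
        Self { api_url: api_url.into(), env_var, key: None }
    }

    pub fn env_var_name(&self) -> &str {
        &self.env_var.name
    }

    // Prefer the environment variable's value when one is set.
    pub fn load_if_needed(&mut self) {
        if self.key.is_none() {
            self.key = self.env_var.value.clone();
        }
    }

    pub fn is_from_env_var(&self) -> bool {
        self.key.is_some() && self.env_var.value.is_some()
    }
}
```

This also explains the `env_var_name()` accessor appearing in later hunks: the name now lives inside `ApiKeyState` rather than in a separate `api_key_env_var` field.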
@@ -733,15 +730,17 @@ impl ConfigurationView {
.child(Label::new("To use local Ollama:"))
.child(
List::new()
.child(InstructionListItem::new(
"Download and install Ollama from",
Some("ollama.com"),
Some("https://ollama.com/download"),
))
.child(InstructionListItem::text_only(
"Start Ollama and download a model: `ollama run gpt-oss:20b`",
))
.child(InstructionListItem::text_only(
.child(
ListBulletItem::new("")
.child(Label::new("Download and install Ollama from"))
.child(ButtonLink::new("ollama.com", "https://ollama.com/download")),
)
.child(
ListBulletItem::new("")
.child(Label::new("Start Ollama and download a model:"))
.child(InlineCode::new("ollama run gpt-oss:20b")),
)
.child(ListBulletItem::new(
"Click 'Connect' below to start using Ollama in Zed",
)),
)


@@ -5,11 +5,11 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
};
use menu;
use open_ai::{
@@ -20,13 +20,9 @@ use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};
use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
@@ -62,12 +58,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = OpenAiLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
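The `|this| &mut this.api_key_state` argument kept by the new `load_if_needed` signature is a projection closure. A hedged sketch of the shape (simplified, not the real gpui signature): the helper re-borrows the field from the owning `State` whenever it runs, instead of capturing a long-lived `&mut` reference.

```rust
// Simplified stand-ins illustrating the projection-closure pattern.
struct ApiKeyState {
    api_url: String,
    key: Option<String>,
}

struct State {
    api_key_state: ApiKeyState,
}

// `get_state` projects the ApiKeyState field out of the owner on demand,
// so the owner can be re-borrowed later (e.g. after an await) without
// holding a mutable borrow across that gap.
fn load_if_needed<T>(
    owner: &mut T,
    api_url: &str,
    get_state: impl Fn(&mut T) -> &mut ApiKeyState,
) -> bool {
    let state = get_state(owner);
    // Only load when the URL matches and no key is cached yet.
    if state.api_url == api_url && state.key.is_none() {
        state.key = Some("key-from-credential-store".to_string());
        true
    } else {
        false
    }
}
```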
@@ -76,17 +68,13 @@ impl OpenAiLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
});
@@ -790,17 +778,17 @@ impl Render for ConfigurationView {
.child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("OpenAI's console"),
Some("https://platform.openai.com/api-keys"),
))
.child(InstructionListItem::text_only(
"Ensure your OpenAI account has credits",
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the assistant",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create one by visiting"))
.child(ButtonLink::new("OpenAI's console", "https://platform.openai.com/api-keys"))
)
.child(
ListBulletItem::new("Ensure your OpenAI account has credits")
)
.child(
ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
),
)
.child(self.api_key_editor.clone())
.child(


@@ -4,10 +4,10 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
};
use menu;
use open_ai::{ResponseStreamEvent, stream_completion};
@@ -16,9 +16,7 @@ use std::sync::Arc;
use ui::{ElevationIndex, Tooltip, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::EnvVar;
use crate::api_key::ApiKeyState;
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
pub use settings::OpenAiCompatibleAvailableModel as AvailableModel;
pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
@@ -38,7 +36,6 @@ pub struct OpenAiCompatibleLanguageModelProvider {
pub struct State {
id: Arc<str>,
api_key_env_var: EnvVar,
api_key_state: ApiKeyState,
settings: OpenAiCompatibleSettings,
}
@@ -56,12 +53,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = SharedString::new(self.settings.api_url.clone());
self.api_key_state.load_if_needed(
api_url,
&self.api_key_env_var,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
@@ -83,7 +76,6 @@ impl OpenAiCompatibleLanguageModelProvider {
let api_url = SharedString::new(settings.api_url.as_str());
this.api_key_state.handle_url_change(
api_url,
&this.api_key_env_var,
|this| &mut this.api_key_state,
cx,
);
@@ -95,8 +87,10 @@ impl OpenAiCompatibleLanguageModelProvider {
let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
State {
id: id.clone(),
api_key_env_var: EnvVar::new(api_key_env_var_name),
api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())),
api_key_state: ApiKeyState::new(
SharedString::new(settings.api_url.as_str()),
EnvVar::new(api_key_env_var_name),
),
settings,
}
});
@@ -437,7 +431,7 @@ impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let state = self.state.read(cx);
let env_var_set = state.api_key_state.is_from_env_var();
let env_var_name = &state.api_key_env_var.name;
let env_var_name = state.api_key_state.env_var_name();
let api_key_section = if self.should_render_editor(cx) {
v_flex()


@@ -4,11 +4,12 @@ use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason, TokenUsage, env_var,
};
use open_router::{
Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, ResponseStreamEvent, list_models,
@@ -17,13 +18,9 @@ use settings::{OpenRouterAvailableModel as AvailableModel, Settings, SettingsSto
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::{EnvVar, env_var};
use crate::ui::ConfiguredApiCard;
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
@@ -62,12 +59,9 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let task = self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
let task = self
.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx);
cx.spawn(async move |this, cx| {
let result = task.await;
@@ -135,7 +129,7 @@ impl OpenRouterLanguageModelProvider {
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,
@@ -830,17 +824,15 @@ impl Render for ConfigurationView {
.child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create an API key by visiting",
Some("OpenRouter's console"),
Some("https://openrouter.ai/keys"),
))
.child(InstructionListItem::text_only(
"Ensure your OpenRouter account has credits",
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the assistant",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create an API key by visiting"))
.child(ButtonLink::new("OpenRouter's console", "https://openrouter.ai/keys"))
)
.child(ListBulletItem::new("Ensure your OpenRouter account has credits")
)
.child(ListBulletItem::new("Paste your API key below and hit enter to start using the assistant")
),
)
.child(self.api_key_editor.clone())
.child(


@@ -4,26 +4,20 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, env_var,
};
use open_ai::ResponseStreamEvent;
pub use settings::VercelAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use vercel::{Model, VERCEL_API_URL};
use zed_env_vars::{EnvVar, env_var};
use crate::{
api_key::ApiKeyState,
ui::{ConfiguredApiCard, InstructionListItem},
};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
@@ -59,12 +53,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = VercelLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
@@ -73,17 +63,13 @@ impl VercelLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
});
@@ -472,14 +458,14 @@ impl Render for ConfigurationView {
.child(Label::new("To use Zed's agent with Vercel v0, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("Vercel v0's console"),
Some("https://v0.dev/chat/settings/keys"),
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the agent",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create one by visiting"))
.child(ButtonLink::new("Vercel v0's console", "https://v0.dev/chat/settings/keys"))
)
.child(
ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
),
)
.child(self.api_key_editor.clone())
.child(


@@ -4,26 +4,21 @@ use futures::{FutureExt, StreamExt, future, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
Role, env_var,
};
use open_ai::ResponseStreamEvent;
pub use settings::XaiAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use x_ai::{Model, XAI_API_URL};
use zed_env_vars::{EnvVar, env_var};
use crate::{
api_key::ApiKeyState,
ui::{ConfiguredApiCard, InstructionListItem},
};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
@@ -59,12 +54,8 @@ impl State {
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
let api_url = XAiLanguageModelProvider::api_url(cx);
self.api_key_state.load_if_needed(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
)
self.api_key_state
.load_if_needed(api_url, |this| &mut this.api_key_state, cx)
}
}
@@ -73,17 +64,13 @@ impl XAiLanguageModelProvider {
let state = cx.new(|cx| {
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let api_url = Self::api_url(cx);
this.api_key_state.handle_url_change(
api_url,
&API_KEY_ENV_VAR,
|this| &mut this.api_key_state,
cx,
);
this.api_key_state
.handle_url_change(api_url, |this| &mut this.api_key_state, cx);
cx.notify();
})
.detach();
State {
api_key_state: ApiKeyState::new(Self::api_url(cx)),
api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
}
});
@@ -474,14 +461,14 @@ impl Render for ConfigurationView {
.child(Label::new("To use Zed's agent with xAI, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("xAI console"),
Some("https://console.x.ai/team/default/api-keys"),
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the agent",
)),
.child(
ListBulletItem::new("")
.child(Label::new("Create one by visiting"))
.child(ButtonLink::new("xAI console", "https://console.x.ai/team/default/api-keys"))
)
.child(
ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
),
)
.child(self.api_key_editor.clone())
.child(


@@ -1,4 +0,0 @@
pub mod configured_api_card;
pub mod instruction_list_item;
pub use configured_api_card::ConfiguredApiCard;
pub use instruction_list_item::InstructionListItem;


@@ -1,69 +0,0 @@
use gpui::{AnyElement, IntoElement, ParentElement, SharedString};
use ui::{ListItem, prelude::*};
/// A reusable list item component for adding LLM provider configuration instructions
pub struct InstructionListItem {
label: SharedString,
button_label: Option<SharedString>,
button_link: Option<String>,
}
impl InstructionListItem {
pub fn new(
label: impl Into<SharedString>,
button_label: Option<impl Into<SharedString>>,
button_link: Option<impl Into<String>>,
) -> Self {
Self {
label: label.into(),
button_label: button_label.map(|l| l.into()),
button_link: button_link.map(|l| l.into()),
}
}
pub fn text_only(label: impl Into<SharedString>) -> Self {
Self {
label: label.into(),
button_label: None,
button_link: None,
}
}
}
impl IntoElement for InstructionListItem {
type Element = AnyElement;
fn into_element(self) -> Self::Element {
let item_content = if let (Some(button_label), Some(button_link)) =
(self.button_label, self.button_link)
{
let link = button_link;
let unique_id = SharedString::from(format!("{}-button", self.label));
h_flex()
.flex_wrap()
.child(Label::new(self.label))
.child(
Button::new(unique_id, button_label)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| cx.open_url(&link)),
)
.into_any_element()
} else {
Label::new(self.label).into_any_element()
};
ListItem::new("list-item")
.selectable(false)
.start_slot(
Icon::new(IconName::Dash)
.size(IconSize::XSmall)
.color(Color::Hidden),
)
.child(div().w_full().child(item_content))
.into_any_element()
}
}


@@ -1,5 +1,6 @@
("(" @open ")" @close)
("[" @open "]" @close)
("{" @open "}" @close)
("<" @open ">" @close)
(("\"" @open "\"" @close) (#set! rainbow.exclude))
(("'" @open "'" @close) (#set! rainbow.exclude))


@@ -286,7 +286,7 @@ impl PromptBuilder {
Ok(())
}
pub fn generate_inline_transformation_prompt_v2(
pub fn generate_inline_transformation_prompt_tools(
&self,
language_name: Option<&LanguageName>,
buffer: BufferSnapshot,


@@ -12,7 +12,7 @@ mod session;
use std::{sync::Arc, time::Duration};
use async_dispatcher::{Dispatcher, Runnable, set_dispatcher};
use gpui::{App, PlatformDispatcher, RunnableVariant};
use gpui::{App, PlatformDispatcher, Priority, RunnableVariant};
use project::Fs;
pub use runtimelib::ExecutionState;
@@ -46,7 +46,7 @@ fn zed_dispatcher(cx: &mut App) -> impl Dispatcher {
impl Dispatcher for ZedDispatcher {
fn dispatch(&self, runnable: Runnable) {
self.dispatcher
.dispatch(RunnableVariant::Compat(runnable), None);
.dispatch(RunnableVariant::Compat(runnable), None, Priority::default());
}
fn dispatch_after(&self, duration: Duration, runnable: Runnable) {


@@ -36,7 +36,13 @@ pub struct AgentSettingsContent {
pub default_model: Option<LanguageModelSelection>,
/// Model to use for the inline assistant. Defaults to default_model when not specified.
pub inline_assistant_model: Option<LanguageModelSelection>,
/// Model to use for generating git commit messages. Defaults to default_model when not specified.
/// Whether to use streaming tools for the inline assistant.
///
/// Default: true
pub inline_assistant_use_streaming_tools: Option<bool>,
/// Model to use for generating git commit messages. Defaults to default_model when not specified.
pub commit_message_model: Option<LanguageModelSelection>,
/// Model to use for generating thread summaries. Defaults to default_model when not specified.
pub thread_summary_model: Option<LanguageModelSelection>,
@@ -129,6 +135,9 @@ impl AgentSettingsContent {
model,
});
}
pub fn set_inline_assistant_use_streaming_tools(&mut self, use_tools: bool) {
self.inline_assistant_use_streaming_tools = Some(use_tools);
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
self.commit_message_model = Some(LanguageModelSelection {


@@ -186,22 +186,20 @@ pub struct CopilotSettingsContent {
pub enterprise_uri: Option<String>,
}
#[with_fallible_options]
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, MergeFrom, PartialEq)]
pub struct CodestralSettingsContent {
/// Model to use for completions.
///
/// Default: "codestral-latest"
#[serde(default)]
pub model: Option<String>,
/// Maximum tokens to generate.
///
/// Default: 150
#[serde(default)]
pub max_tokens: Option<u32>,
/// Api URL to use for completions.
///
/// Default: "https://codestral.mistral.ai"
#[serde(default)]
pub api_url: Option<String>,
}
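Each field above is optional with a documented default. A hedged sketch of how such `Option`-valued settings typically resolve (the accessor methods here are illustrative, not the real settings API):

```rust
// Illustrative resolution of the documented defaults for the optional
// Codestral settings fields shown above.
#[derive(Default)]
struct CodestralSettings {
    model: Option<String>,
    max_tokens: Option<u32>,
    api_url: Option<String>,
}

impl CodestralSettings {
    fn model(&self) -> &str {
        // Default documented as "codestral-latest".
        self.model.as_deref().unwrap_or("codestral-latest")
    }
    fn max_tokens(&self) -> u32 {
        // Default documented as 150.
        self.max_tokens.unwrap_or(150)
    }
    fn api_url(&self) -> &str {
        // Default documented as "https://codestral.mistral.ai".
        self.api_url.as_deref().unwrap_or("https://codestral.mistral.ai")
    }
}
```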


@@ -18,6 +18,9 @@ test-support = []
[dependencies]
anyhow.workspace = true
bm25 = "2.3.2"
copilot.workspace = true
edit_prediction.workspace = true
language_models.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@@ -38,8 +41,8 @@ strum.workspace = true
telemetry.workspace = true
theme.workspace = true
title_bar.workspace = true
ui.workspace = true
ui_input.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true


@@ -2,10 +2,12 @@ mod dropdown;
mod font_picker;
mod icon_theme_picker;
mod input_field;
mod section_items;
mod theme_picker;
pub use dropdown::*;
pub use font_picker::font_picker;
pub use icon_theme_picker::icon_theme_picker;
pub use input_field::*;
pub use section_items::*;
pub use theme_picker::theme_picker;


@@ -13,6 +13,7 @@ pub struct SettingsInputField {
tab_index: Option<isize>,
}
// TODO: Update the `ui_input::InputField` to use `window.use_state` and `RenderOnce` and remove this component
impl SettingsInputField {
pub fn new() -> Self {
Self {


@@ -0,0 +1,56 @@
use gpui::{IntoElement, ParentElement, Styled};
use ui::{Divider, DividerColor, prelude::*};
#[derive(IntoElement)]
pub struct SettingsSectionHeader {
icon: Option<IconName>,
label: SharedString,
no_padding: bool,
}
impl SettingsSectionHeader {
pub fn new(label: impl Into<SharedString>) -> Self {
Self {
label: label.into(),
icon: None,
no_padding: false,
}
}
pub fn icon(mut self, icon: IconName) -> Self {
self.icon = Some(icon);
self
}
pub fn no_padding(mut self, no_padding: bool) -> Self {
self.no_padding = no_padding;
self
}
}
impl RenderOnce for SettingsSectionHeader {
fn render(self, _: &mut Window, cx: &mut App) -> impl IntoElement {
let label = Label::new(self.label)
.size(LabelSize::Small)
.color(Color::Muted)
.buffer_font(cx);
v_flex()
.w_full()
.when(!self.no_padding, |this| this.px_8())
.gap_1p5()
.map(|this| {
if self.icon.is_some() {
this.child(
h_flex()
.gap_1p5()
.child(Icon::new(self.icon.unwrap()).color(Color::Muted))
.child(label),
)
} else {
this.child(label)
}
})
.child(Divider::horizontal().color(DividerColor::BorderFaded))
}
}
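`SettingsSectionHeader` follows the consuming-builder style used throughout this diff: each configuration method takes `mut self`, sets a field, and returns `Self` so calls chain. A minimal plain-Rust sketch of that pattern (the `Header` type below is a stand-in, not the real component):

```rust
// Consuming-builder pattern: methods take ownership and return Self,
// enabling chains like Header::new("General").icon("settings").
struct Header {
    label: String,
    icon: Option<&'static str>,
    no_padding: bool,
}

impl Header {
    fn new(label: impl Into<String>) -> Self {
        Self { label: label.into(), icon: None, no_padding: false }
    }
    fn icon(mut self, icon: &'static str) -> Self {
        self.icon = Some(icon);
        self
    }
    fn no_padding(mut self, no_padding: bool) -> Self {
        self.no_padding = no_padding;
        self
    }
}
```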


@@ -2330,8 +2330,12 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
// Note that `crates/json_schema_store` solves the same problem, there is probably a way to unify the two
items.push(SettingsPageItem::SectionHeader(LANGUAGES_SECTION_HEADER));
items.extend(all_language_names(cx).into_iter().map(|language_name| {
let link = format!("languages.{language_name}");
SettingsPageItem::SubPageLink(SubPageLink {
title: language_name,
description: None,
json_path: Some(link.leak()),
in_json: true,
files: USER | PROJECT,
render: Arc::new(|this, window, cx| {
this.render_sub_page_items(
@@ -6013,7 +6017,7 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
title: "In Text Threads",
title: "Display In Text Threads",
description: "Whether edit predictions are enabled when editing text threads in the agent panel.",
field: Box::new(SettingField {
json_path: Some("edit_prediction.in_text_threads"),
@@ -6027,42 +6031,6 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
metadata: None,
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Copilot Provider",
description: "Use GitHub Copilot as your edit prediction provider.",
field: Box::new(
SettingField {
json_path: Some("edit_prediction.copilot_provider"),
pick: |settings_content| {
settings_content.project.all_languages.edit_predictions.as_ref()?.copilot.as_ref()
},
write: |settings_content, value| {
settings_content.project.all_languages.edit_predictions.get_or_insert_default().copilot = value;
},
}
.unimplemented(),
),
metadata: None,
files: USER | PROJECT,
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Codestral Provider",
description: "Use Mistral's Codestral as your edit prediction provider.",
field: Box::new(
SettingField {
json_path: Some("edit_prediction.codestral_provider"),
pick: |settings_content| {
settings_content.project.all_languages.edit_predictions.as_ref()?.codestral.as_ref()
},
write: |settings_content, value| {
settings_content.project.all_languages.edit_predictions.get_or_insert_default().codestral = value;
},
}
.unimplemented(),
),
metadata: None,
files: USER | PROJECT,
}),
]
);
items
@@ -7485,9 +7453,23 @@ fn non_editor_language_settings_data() -> Vec<SettingsPageItem> {
fn edit_prediction_language_settings_section() -> Vec<SettingsPageItem> {
vec![
SettingsPageItem::SectionHeader("Edit Predictions"),
SettingsPageItem::SubPageLink(SubPageLink {
title: "Configure Providers".into(),
json_path: Some("edit_predictions.providers"),
description: Some("Set up additional edit prediction providers to complement Zed's built-in Zeta model.".into()),
in_json: false,
files: USER,
render: Arc::new(|_, window, cx| {
let settings_window = cx.entity();
let page = window.use_state(cx, |_, _| {
crate::pages::EditPredictionSetupPage::new(settings_window)
});
page.into_any_element()
}),
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Show Edit Predictions",
description: "Controls whether edit predictions are shown immediately or manually by triggering `editor::showeditprediction` (false).",
description: "Controls whether edit predictions are shown immediately or manually.",
field: Box::new(SettingField {
json_path: Some("languages.$(language).show_edit_predictions"),
pick: |settings_content| {
@@ -7505,7 +7487,7 @@ fn edit_prediction_language_settings_section() -> Vec<SettingsPageItem> {
files: USER | PROJECT,
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Edit Predictions Disabled In",
title: "Disable in Language Scopes",
description: "Controls whether edit predictions are shown in the given language scopes.",
field: Box::new(
SettingField {


@@ -0,0 +1,2 @@
mod edit_prediction_provider_setup;
pub use edit_prediction_provider_setup::EditPredictionSetupPage;


@@ -0,0 +1,365 @@
use edit_prediction::{
ApiKeyState, Zeta2FeatureFlag,
mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
};
use feature_flags::FeatureFlagAppExt as _;
use gpui::{Entity, ScrollHandle, prelude::*};
use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
use crate::{
SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
components::{SettingsInputField, SettingsSectionHeader},
};
pub struct EditPredictionSetupPage {
settings_window: Entity<SettingsWindow>,
scroll_handle: ScrollHandle,
}
impl EditPredictionSetupPage {
pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
Self {
settings_window,
scroll_handle: ScrollHandle::new(),
}
}
}
impl Render for EditPredictionSetupPage {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings_window = self.settings_window.clone();
let providers = [
Some(render_github_copilot_provider(window, cx).into_any_element()),
cx.has_flag::<Zeta2FeatureFlag>().then(|| {
render_api_key_provider(
IconName::Inception,
"Mercury",
"https://platform.inceptionlabs.ai/dashboard/api-keys".into(),
mercury_api_token(cx),
|_cx| MERCURY_CREDENTIALS_URL,
None,
window,
cx,
)
.into_any_element()
}),
cx.has_flag::<Zeta2FeatureFlag>().then(|| {
render_api_key_provider(
IconName::SweepAi,
"Sweep",
"https://app.sweep.dev/".into(),
sweep_api_token(cx),
|_cx| SWEEP_CREDENTIALS_URL,
None,
window,
cx,
)
.into_any_element()
}),
Some(
render_api_key_provider(
IconName::AiMistral,
"Codestral",
"https://console.mistral.ai/codestral".into(),
codestral_api_key(cx),
|cx| language_models::MistralLanguageModelProvider::api_url(cx),
Some(settings_window.update(cx, |settings_window, cx| {
let codestral_settings = codestral_settings();
settings_window
.render_sub_page_items_section(
codestral_settings.iter().enumerate(),
None,
window,
cx,
)
.into_any_element()
})),
window,
cx,
)
.into_any_element(),
),
];
div()
.size_full()
.vertical_scrollbar_for(&self.scroll_handle, window, cx)
.child(
v_flex()
.id("ep-setup-page")
.min_w_0()
.size_full()
.px_8()
.pb_16()
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
.children(providers.into_iter().flatten()),
)
}
}
fn render_api_key_provider(
icon: IconName,
title: &'static str,
link: SharedString,
api_key_state: Entity<ApiKeyState>,
current_url: fn(&mut App) -> SharedString,
additional_fields: Option<AnyElement>,
window: &mut Window,
cx: &mut Context<EditPredictionSetupPage>,
) -> impl IntoElement {
let weak_page = cx.weak_entity();
_ = window.use_keyed_state(title, cx, |_, cx| {
let task = api_key_state.update(cx, |key_state, cx| {
key_state.load_if_needed(current_url(cx), |state| state, cx)
});
cx.spawn(async move |_, cx| {
task.await.ok();
weak_page
.update(cx, |_, cx| {
cx.notify();
})
.ok();
})
});
let (has_key, env_var_name, is_from_env_var) = api_key_state.read_with(cx, |state, _| {
(
state.has_key(),
Some(state.env_var_name().clone()),
state.is_from_env_var(),
)
});
let write_key = move |api_key: Option<String>, cx: &mut App| {
api_key_state
.update(cx, |key_state, cx| {
let url = current_url(cx);
key_state.store(url, api_key, |key_state| key_state, cx)
})
.detach_and_log_err(cx);
};
let base_container = v_flex().id(title).min_w_0().pt_8().gap_1p5();
let header = SettingsSectionHeader::new(title)
.icon(icon)
.no_padding(true);
let button_link_label = format!("{} dashboard", title);
let description = h_flex()
.min_w_0()
.gap_0p5()
.child(
Label::new("Visit the")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
ButtonLink::new(button_link_label, link)
.no_icon(true)
.label_size(LabelSize::Small)
.label_color(Color::Muted),
)
.child(
Label::new("to generate an API key.")
.size(LabelSize::Small)
.color(Color::Muted),
);
let configured_card_label = if is_from_env_var {
"API Key Set in Environment Variable"
} else {
"API Key Configured"
};
let container = if has_key {
base_container.child(header).child(
ConfiguredApiCard::new(configured_card_label)
.button_label("Reset Key")
.button_tab_index(0)
.disabled(is_from_env_var)
.when_some(env_var_name, |this, env_var_name| {
this.when(is_from_env_var, |this| {
this.tooltip_label(format!(
"To reset your API key, unset the {} environment variable.",
env_var_name
))
})
})
.on_click(move |_, _, cx| {
write_key(None, cx);
}),
)
} else {
base_container.child(header).child(
h_flex()
.pt_2p5()
.w_full()
.justify_between()
.child(
v_flex()
.w_full()
.max_w_1_2()
.child(Label::new("API Key"))
.child(description)
.when_some(env_var_name, |this, env_var_name| {
this.child({
let label = format!(
"Or set the {} env var and restart Zed.",
env_var_name.as_ref()
);
Label::new(label).size(LabelSize::Small).color(Color::Muted)
})
}),
)
.child(
SettingsInputField::new()
.tab_index(0)
.with_placeholder("xxxxxxxxxxxxxxxxxxxx")
.on_confirm(move |api_key, cx| {
write_key(api_key.filter(|key| !key.is_empty()), cx);
}),
),
)
};
container.when_some(additional_fields, |this, additional_fields| {
this.child(
div()
.map(|this| if has_key { this.mt_1() } else { this.mt_4() })
.px_neg_8()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.child(additional_fields),
)
})
}
fn codestral_settings() -> Box<[SettingsPageItem]> {
Box::new([
SettingsPageItem::SettingItem(SettingItem {
title: "API URL",
description: "The API URL to use for Codestral.",
field: Box::new(SettingField {
pick: |settings| {
settings
.project
.all_languages
.edit_predictions
.as_ref()?
.codestral
.as_ref()?
.api_url
.as_ref()
},
write: |settings, value| {
settings
.project
.all_languages
.edit_predictions
.get_or_insert_default()
.codestral
.get_or_insert_default()
.api_url = value;
},
json_path: Some("edit_predictions.codestral.api_url"),
}),
metadata: Some(Box::new(SettingsFieldMetadata {
placeholder: Some(CODESTRAL_API_URL),
..Default::default()
})),
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Max Tokens",
description: "The maximum number of tokens to generate.",
field: Box::new(SettingField {
pick: |settings| {
settings
.project
.all_languages
.edit_predictions
.as_ref()?
.codestral
.as_ref()?
.max_tokens
.as_ref()
},
write: |settings, value| {
settings
.project
.all_languages
.edit_predictions
.get_or_insert_default()
.codestral
.get_or_insert_default()
.max_tokens = value;
},
json_path: Some("edit_predictions.codestral.max_tokens"),
}),
metadata: None,
files: USER,
}),
SettingsPageItem::SettingItem(SettingItem {
title: "Model",
description: "The Codestral model ID to use.",
field: Box::new(SettingField {
pick: |settings| {
settings
.project
.all_languages
.edit_predictions
.as_ref()?
.codestral
.as_ref()?
.model
.as_ref()
},
write: |settings, value| {
settings
.project
.all_languages
.edit_predictions
.get_or_insert_default()
.codestral
.get_or_insert_default()
.model = value;
},
json_path: Some("edit_predictions.codestral.model"),
}),
metadata: Some(Box::new(SettingsFieldMetadata {
placeholder: Some("codestral-latest"),
..Default::default()
})),
files: USER,
}),
])
}
pub(crate) fn render_github_copilot_provider(
window: &mut Window,
cx: &mut App,
) -> impl IntoElement {
let configuration_view = window.use_state(cx, |_, cx| {
copilot::ConfigurationView::new(
|cx| {
copilot::Copilot::global(cx)
.is_some_and(|copilot| copilot.read(cx).is_authenticated())
},
copilot::ConfigurationMode::EditPrediction,
cx,
)
});
v_flex()
.id("github-copilot")
.min_w_0()
.gap_1p5()
.child(
SettingsSectionHeader::new("GitHub Copilot")
.icon(IconName::Copilot)
.no_padding(true),
)
.child(configuration_view)
}


@@ -1,5 +1,6 @@
mod components;
mod page_data;
mod pages;
use anyhow::Result;
use editor::{Editor, EditorEvent};
@@ -28,9 +29,8 @@ use std::{
};
use title_bar::platform_title_bar::PlatformTitleBar;
use ui::{
Banner, ContextMenu, Divider, DividerColor, DropdownMenu, DropdownStyle, IconButtonShape,
KeyBinding, KeybindingHint, PopoverMenu, Switch, Tooltip, TreeViewItem, WithScrollbar,
prelude::*,
Banner, ContextMenu, Divider, DropdownMenu, DropdownStyle, IconButtonShape, KeyBinding,
KeybindingHint, PopoverMenu, Switch, Tooltip, TreeViewItem, WithScrollbar, prelude::*,
};
use ui_input::{NumberField, NumberFieldType};
use util::{ResultExt as _, paths::PathStyle, rel_path::RelPath};
@@ -38,7 +38,8 @@ use workspace::{AppState, OpenOptions, OpenVisible, Workspace, client_side_decor
use zed_actions::{OpenProjectSettings, OpenSettings, OpenSettingsAt};
use crate::components::{
EnumVariantDropdown, SettingsInputField, font_picker, icon_theme_picker, theme_picker,
EnumVariantDropdown, SettingsInputField, SettingsSectionHeader, font_picker, icon_theme_picker,
theme_picker,
};
const NAVBAR_CONTAINER_TAB_INDEX: isize = 0;
@@ -613,7 +614,10 @@ pub fn open_settings_editor(
app_id: Some(app_id.to_owned()),
window_decorations: Some(window_decorations),
window_min_size: Some(gpui::Size {
width: px(360.0),
// Don't make the settings window narrower than this;
// otherwise it becomes unusable. Users with lower-resolution
// monitors can customize the height, but not the width.
width: px(900.0),
height: px(240.0),
}),
window_bounds: Some(WindowBounds::centered(scaled_bounds, cx)),
@@ -834,18 +838,9 @@ impl SettingsPageItem {
};
match self {
SettingsPageItem::SectionHeader(header) => v_flex()
.w_full()
.px_8()
.gap_1p5()
.child(
Label::new(SharedString::new_static(header))
.size(LabelSize::Small)
.color(Color::Muted)
.buffer_font(cx),
)
.child(Divider::horizontal().color(DividerColor::BorderFaded))
.into_any_element(),
SettingsPageItem::SectionHeader(header) => {
SettingsSectionHeader::new(SharedString::new_static(header)).into_any_element()
}
SettingsPageItem::SettingItem(setting_item) => {
let (field_with_padding, _) =
render_setting_item_inner(setting_item, true, false, cx);
@@ -869,9 +864,20 @@ impl SettingsPageItem {
.map(apply_padding)
.child(
v_flex()
.relative()
.w_full()
.max_w_1_2()
.child(Label::new(sub_page_link.title.clone())),
.child(Label::new(sub_page_link.title.clone()))
.when_some(
sub_page_link.description.as_ref(),
|this, description| {
this.child(
Label::new(description.clone())
.size(LabelSize::Small)
.color(Color::Muted),
)
},
),
)
.child(
Button::new(
@@ -909,7 +915,13 @@ impl SettingsPageItem {
this.push_sub_page(sub_page_link.clone(), header, cx)
})
}),
),
)
.child(render_settings_item_link(
sub_page_link.title.clone(),
sub_page_link.json_path,
false,
cx,
)),
)
.when(!is_last, |this| this.child(Divider::horizontal()))
.into_any_element(),
@@ -983,20 +995,6 @@ fn render_settings_item(
let (found_in_file, _) = setting_item.field.file_set_in(file.clone(), cx);
let file_set_in = SettingsUiFile::from_settings(found_in_file.clone());
let clipboard_has_link = cx
.read_from_clipboard()
.and_then(|entry| entry.text())
.map_or(false, |maybe_url| {
setting_item.field.json_path().is_some()
&& maybe_url.strip_prefix("zed://settings/") == setting_item.field.json_path()
});
let (link_icon, link_icon_color) = if clipboard_has_link {
(IconName::Check, Color::Success)
} else {
(IconName::Link, Color::Muted)
};
h_flex()
.id(setting_item.title)
.min_w_0()
@@ -1056,42 +1054,62 @@ fn render_settings_item(
)
.child(control)
.when(sub_page_stack().is_empty(), |this| {
// Intentionally using the description to make the icon button
// unique because some items share the same title (e.g., "Font Size")
let icon_button_id =
SharedString::new(format!("copy-link-btn-{}", setting_item.description));
this.child(
div()
.absolute()
.top(rems_from_px(18.))
.map(|this| {
if sub_field {
this.visible_on_hover("setting-sub-item")
.left(rems_from_px(-8.5))
} else {
this.visible_on_hover("setting-item")
.left(rems_from_px(-22.))
}
})
.child({
IconButton::new(icon_button_id, link_icon)
.icon_color(link_icon_color)
.icon_size(IconSize::Small)
.shape(IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Link"))
.when_some(setting_item.field.json_path(), |this, path| {
this.on_click(cx.listener(move |_, _, _, cx| {
let link = format!("zed://settings/{}", path);
cx.write_to_clipboard(ClipboardItem::new_string(link));
cx.notify();
}))
})
}),
)
this.child(render_settings_item_link(
setting_item.description,
setting_item.field.json_path(),
sub_field,
cx,
))
})
}
fn render_settings_item_link(
id: impl Into<ElementId>,
json_path: Option<&'static str>,
sub_field: bool,
cx: &mut Context<'_, SettingsWindow>,
) -> impl IntoElement {
let clipboard_has_link = cx
.read_from_clipboard()
.and_then(|entry| entry.text())
.map_or(false, |maybe_url| {
json_path.is_some() && maybe_url.strip_prefix("zed://settings/") == json_path
});
let (link_icon, link_icon_color) = if clipboard_has_link {
(IconName::Check, Color::Success)
} else {
(IconName::Link, Color::Muted)
};
div()
.absolute()
.top(rems_from_px(18.))
.map(|this| {
if sub_field {
this.visible_on_hover("setting-sub-item")
.left(rems_from_px(-8.5))
} else {
this.visible_on_hover("setting-item")
.left(rems_from_px(-22.))
}
})
.child(
IconButton::new((id.into(), "copy-link-btn"), link_icon)
.icon_color(link_icon_color)
.icon_size(IconSize::Small)
.shape(IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Link"))
.when_some(json_path, |this, path| {
this.on_click(cx.listener(move |_, _, _, cx| {
let link = format!("zed://settings/{}", path);
cx.write_to_clipboard(ClipboardItem::new_string(link));
cx.notify();
}))
}),
)
}
struct SettingItem {
title: &'static str,
description: &'static str,
@@ -1175,6 +1193,12 @@ impl PartialEq for SettingItem {
#[derive(Clone)]
struct SubPageLink {
title: SharedString,
description: Option<SharedString>,
/// See [`SettingField::json_path`]
json_path: Option<&'static str>,
/// Whether the settings in this sub page are configurable in settings.json.
/// When `false`, the "Edit in settings.json" button is removed from the page.
in_json: bool,
files: FileMask,
render: Arc<
dyn Fn(&mut SettingsWindow, &mut Window, &mut Context<SettingsWindow>) -> AnyElement
@@ -1835,6 +1859,7 @@ impl SettingsWindow {
header_str = *header;
}
SettingsPageItem::SubPageLink(sub_page_link) => {
json_path = sub_page_link.json_path;
documents.push(bm25::Document {
id: key_index,
contents: [page.title, header_str, sub_page_link.title.as_ref()]
@@ -2758,19 +2783,49 @@ impl SettingsWindow {
page_content
}
fn render_sub_page_items<'a, Items: Iterator<Item = (usize, &'a SettingsPageItem)>>(
fn render_sub_page_items<'a, Items>(
&self,
items: Items,
page_index: Option<usize>,
window: &mut Window,
cx: &mut Context<SettingsWindow>,
) -> impl IntoElement {
let mut page_content = v_flex()
) -> impl IntoElement
where
Items: Iterator<Item = (usize, &'a SettingsPageItem)>,
{
let page_content = v_flex()
.id("settings-ui-page")
.size_full()
.overflow_y_scroll()
.track_scroll(&self.sub_page_scroll_handle);
self.render_sub_page_items_in(page_content, items, page_index, window, cx)
}
fn render_sub_page_items_section<'a, Items>(
&self,
items: Items,
page_index: Option<usize>,
window: &mut Window,
cx: &mut Context<SettingsWindow>,
) -> impl IntoElement
where
Items: Iterator<Item = (usize, &'a SettingsPageItem)>,
{
let page_content = v_flex().id("settings-ui-sub-page-section").size_full();
self.render_sub_page_items_in(page_content, items, page_index, window, cx)
}
fn render_sub_page_items_in<'a, Items>(
&self,
mut page_content: Stateful<Div>,
items: Items,
page_index: Option<usize>,
window: &mut Window,
cx: &mut Context<SettingsWindow>,
) -> impl IntoElement
where
Items: Iterator<Item = (usize, &'a SettingsPageItem)>,
{
let items: Vec<_> = items.collect();
let items_len = items.len();
let mut section_header = None;
@@ -2871,18 +2926,25 @@ impl SettingsWindow {
)
.child(self.render_sub_page_breadcrumbs()),
)
.child(
Button::new("open-in-settings-file", "Edit in settings.json")
.tab_index(0_isize)
.style(ButtonStyle::OutlinedGhost)
.tooltip(Tooltip::for_action_title_in(
"Edit in settings.json",
&OpenCurrentFile,
&self.focus_handle,
))
.on_click(cx.listener(|this, _, window, cx| {
this.open_current_settings_file(window, cx);
})),
.when(
sub_page_stack()
.last()
.is_none_or(|sub_page| sub_page.link.in_json),
|this| {
this.child(
Button::new("open-in-settings-file", "Edit in settings.json")
.tab_index(0_isize)
.style(ButtonStyle::OutlinedGhost)
.tooltip(Tooltip::for_action_title_in(
"Edit in settings.json",
&OpenCurrentFile,
&self.focus_handle,
))
.on_click(cx.listener(|this, _, window, cx| {
this.open_current_settings_file(window, cx);
})),
)
},
)
.into_any_element();


@@ -167,7 +167,7 @@ impl Render for TitleBar {
.child(self.render_project_name(cx))
})
.when(title_bar_settings.show_branch_name, |title_bar| {
title_bar.children(self.render_project_branch(cx))
title_bar.children(self.render_project_repo(cx))
})
})
})
@@ -319,6 +319,27 @@ impl TitleBar {
}
}
fn project_name(&self, cx: &Context<Self>) -> Option<SharedString> {
self.project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| {
let worktree = worktree.read(cx);
let settings_location = SettingsLocation {
worktree_id: worktree.id(),
path: RelPath::empty(),
};
let settings = WorktreeSettings::get(Some(settings_location), cx);
let name = match &settings.project_name {
Some(name) => name.as_str(),
None => worktree.root_name_str(),
};
SharedString::new(name)
})
.next()
}
fn render_remote_project_connection(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let options = self.project.read(cx).remote_connection_options(cx)?;
let host: SharedString = options.display_name().into();
@@ -451,27 +472,10 @@ impl TitleBar {
}
pub fn render_project_name(&self, cx: &mut Context<Self>) -> impl IntoElement {
let name = self
.project
.read(cx)
.visible_worktrees(cx)
.map(|worktree| {
let worktree = worktree.read(cx);
let settings_location = SettingsLocation {
worktree_id: worktree.id(),
path: RelPath::empty(),
};
let settings = WorktreeSettings::get(Some(settings_location), cx);
match &settings.project_name {
Some(name) => name.as_str(),
None => worktree.root_name_str(),
}
})
.next();
let name = self.project_name(cx);
let is_project_selected = name.is_some();
let name = if let Some(name) = name {
util::truncate_and_trailoff(name, MAX_PROJECT_NAME_LENGTH)
util::truncate_and_trailoff(&name, MAX_PROJECT_NAME_LENGTH)
} else {
"Open recent project".to_string()
};
@@ -500,9 +504,10 @@ impl TitleBar {
}))
}
pub fn render_project_branch(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
pub fn render_project_repo(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
let settings = TitleBarSettings::get_global(cx);
let repository = self.project.read(cx).active_repository(cx)?;
let repository_count = self.project.read(cx).repositories(cx).len();
let workspace = self.workspace.upgrade()?;
let repo = repository.read(cx);
let branch_name = repo
@@ -519,6 +524,19 @@ impl TitleBar {
.collect::<String>()
})
})?;
let project_name = self.project_name(cx);
let repo_name = repo
.work_directory_abs_path
.file_name()
.and_then(|name| name.to_str())
.map(SharedString::new);
let show_repo_name =
repository_count > 1 && repo.branch.is_some() && repo_name != project_name;
let branch_name = if let Some(repo_name) = repo_name.filter(|_| show_repo_name) {
format!("{repo_name}/{branch_name}")
} else {
branch_name
};
Some(
Button::new("project_branch_trigger", branch_name)


@@ -1,3 +1,4 @@
mod ai;
mod avatar;
mod banner;
mod button;
@@ -16,6 +17,7 @@ mod icon;
mod image;
mod indent_guides;
mod indicator;
mod inline_code;
mod keybinding;
mod keybinding_hint;
mod label;
@@ -43,6 +45,7 @@ mod tree_view_item;
#[cfg(feature = "stories")]
mod stories;
pub use ai::*;
pub use avatar::*;
pub use banner::*;
pub use button::*;
@@ -61,6 +64,7 @@ pub use icon::*;
pub use image::*;
pub use indent_guides::*;
pub use indicator::*;
pub use inline_code::*;
pub use keybinding::*;
pub use keybinding_hint::*;
pub use label::*;


@@ -0,0 +1,3 @@
mod configured_api_card;
pub use configured_api_card::*;

Some files were not shown because too many files have changed in this diff.