Compare commits

..

27 Commits

Author SHA1 Message Date
Bennet Bo Fenner
1dd21ad8f7 agent: Re-introduce preview when hovering selection in context strip 2025-04-25 12:16:32 +02:00
Marshall Bowers
187f851613 feature_flags: Add FeatureFlag suffix to feature flag types (#29392)
This PR adds the `FeatureFlag` suffix to the feature flag types that
were missing them.

This makes the names easier to search in the codebase.

Release Notes:

- N/A
2025-04-25 04:07:49 +00:00
Marshall Bowers
a77db45865 feature_flags: Remove remoting feature flag (#29390)
This PR removes the `remoting` feature flag.

The feature is shipped, and we aren't referencing the flag anywhere
anymore.

Release Notes:

- N/A
2025-04-25 03:41:11 +00:00
Marshall Bowers
6bb6be826d language_models: Use POST /completions endpoint for Zed provider (#29389)
This PR updates the Zed provider to use the `POST /completions`
endpoint.

There is no functional difference from `POST /completion`, but the
pluralized version reads better.

Release Notes:

- N/A
2025-04-25 02:58:02 +00:00
Michael Sloan
7d9a55d101 Bring back reload of agent context before sending message (#29385)
Realized after merging #29233 that this behavior is desired

Release Notes:

- N/A
2025-04-24 20:32:53 -06:00
Max Brunsfeld
57d8397f53 Remove unnecessary fields from the tool schemas (#29381)
This PR removes two fields from JSON schemas (`$schema` and `title`),
which are not expected by any model provider, but were spuriously
included by our JSON schema library, `schemars`.

These added noise to requests and cost wasted input tokens.

### Old

```json
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "title": "FetchToolInput",
  "type": "object",
  "required": [
    "url"
  ],
  "properties": {
    "url": {
      "description": "The URL to fetch.",
      "type": "string"
    }
  }
}
```

### New:

```json
{
  "properties": {
    "url": {
      "description": "The URL to fetch.",
      "type": "string"
    }
  },
  "required": [
    "url"
  ],
  "type": "object"
}
```

- N/A
2025-04-24 18:09:25 -07:00
Michael Sloan
17ecf94f6f Restructure agent context (#29233)
Simplifies the data structures involved in agent context by removing
caching and limiting the use of ContextId:

* `AssistantContext` enum is now like an ID / handle to context that
does not need to be updated. `ContextId` still exists but is only used
for generating unique `ElementId`.
* `ContextStore` has a `IndexMap<ContextSetEntry>`. Only need to keep a
`HashSet<ThreadId>` consistent with it. `ContextSetEntry` is a newtype
wrapper around `AssistantContext` which implements eq / hash on a subset
of fields.
* Thread `Message` directly stores its context.

Fixes the following bugs:

* If a context entry is removed from the strip and added again, it was
reincluded in the next message.
* Clicking file context in the thread that has been removed from the
context strip didn't jump to the file.
* Refresh of directory context didn't reflect added / removed files.
* Deleted directories would remain in the message editor context strip.
* Token counting requests didn't include image context.
* File, directory, and symbol context deduplication relied on
`ProjectPath` for identity, and so didn't handle renames.
* Symbol context line numbers didn't update when shifted

Known bugs (not fixed):

* Deleting a directory causes it to disappear from messages in threads.
Fixing this in a nice way is tricky. One easy fix is to store the
original path and show that on deletion. It's weird that deletion would
cause the name to "revert", though. Another possibility would be to
snapshot context metadata on add (ala `AddedContext`), and keep that
around despite deletion.

Release Notes:

- N/A
2025-04-24 21:29:33 +00:00
Nathan Sobo
d492939bed Back off the eval to once a day for now (#29378)
cc @maxbrunsfeld 

Release Notes:

- N/A
2025-04-24 14:54:54 -06:00
Richard Feldman
720dfee803 Treat invalid JSON in tool calls as failed tool calls (#29375)
Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-04-24 16:54:27 -04:00
Anthony Eid
a98c648201 debugger: Fix spawned debug adapters taking over Zed's shell (#29373)
This fixes a bug where Zed wasn't closable via ctl-c in the shell it was
spawned in after starting a debug adapter

Release Notes:

- N/A

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-24 20:46:25 +00:00
Conrad Irwin
c147daae4a Terminal in debugger (#29328)
- **debug-terminal**
- **Use terminal inside debugger to spawn commands**

Closes #ISSUE

Release Notes:

- N/A
2025-04-24 14:26:09 -06:00
Smit Barmase
d3911e34de editor: Move blame popover from hover_tooltip to editor prepaint (#29320)
WIP!

In light of having more control over blame popover from editor.

This fixes: https://github.com/zed-industries/zed/issues/28645,
https://github.com/zed-industries/zed/issues/26304

- [x] Initial rendering
- [x] Handle smart positioning (edge detection, etc)
- [x] Delayed hovering, release, etc
- [x] Test blame message selection
- [x] Fix tagged issues

Release Notes:

- Git inline blame popover now dismisses when the cursor is moved, the
editor is scrolled, or the command palette is opened.
2025-04-25 01:52:24 +05:30
Danilo Leal
87f85f1863 Rename "Prompt Library" to "Rules Library" (#29349)
There's probably more to do to fully make the transition, and we'll
still debate a bit internally whether this is the name, but just opening
this PR up now for visibility.

Release Notes:

- N/A
2025-04-24 16:42:06 -03:00
Dan Dascalescu
1a4dab97db docs: Remove redundant word in "Configuring Zed" (#29364)
Release Notes:

- N/A
2025-04-24 21:42:19 +03:00
Bennet Bo Fenner
cd365b0cf5 gemini: Fix issue when deserializing tool call (#29363)
Fixes a regression introduced in #29322

Release Notes:

- N/A

Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-24 18:19:05 +00:00
Agus Zubiaga
58604fba86 agent: Do not reuse assistant message across generations (#29360)
#29354 introduced a bug where we would append tool uses to the last
assistant message even if it was from a previous request.

Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-24 17:56:47 +00:00
shenjack
b0609272c0 ollama: Add DeepSeek v3 max token length (#29156)
Add deepseek-v3 max token length for ollama

Release Notes:

- N/A
2025-04-24 13:20:22 -04:00
Peter Tripp
a17807d8b1 docs: Rust-analyzer example settings for alternate targets (#29353)
Release Notes:

- N/A
2025-04-24 13:12:41 -04:00
Agus Zubiaga
f81e65ae7c agent: Do not create user messages for tool results in thread (#29354)
We used to insert empty user messages into the `Thread::messages` `Vec`
when tools finished running and then we would attach the results when
creating the request. This approach was very easy to mess up during
state handling, leading to empty user messages displayed in the
conversation and API failures.

Instead, we will no longer insert actual user messages for tool results
to the `Thread`, and will only do this on the fly when creating the
model request. This simplifies a lot of code and show fix the mentioned
errors.

Release Notes:

- agent: Improve reliability of LLM requests when including tool results

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-04-24 16:30:15 +00:00
Marshall Bowers
952fe34aaa anthropic: Remove list of supported countries (#29346)
This PR removes the list of supported countries from the `anthropic`
crate, as it is no longer referenced in this repo.

Release Notes:

- N/A
2025-04-24 15:17:33 +00:00
Marshall Bowers
f527df6fa1 google_ai: Remove list of supported countries (#29348)
This PR removes the list of supported countries from the `google_ai`
crate, as it is no longer referenced in this repo.

Release Notes:

- N/A
2025-04-24 15:04:45 +00:00
Marshall Bowers
b54bbebc03 open_ai: Remove list of supported countries (#29347)
This PR removes the list of supported countries from the `open_ai`
crate, as it is no longer referenced in this repo.

Release Notes:

- N/A
2025-04-24 14:55:37 +00:00
Vojtěch Hořánek
8bb7a1f9e7 Remove linked_edits issue description from Elm doc (#29350)
The known issue with `linked_edits` seems to be fixed in this PR:
https://github.com/elm-tooling/elm-language-server/pull/1364. This PR
removes the section from Zeds documentation to avoid confusion.

Release Notes:

- Remove known issues section from Elm documentation.
2025-04-24 14:54:12 +00:00
Danilo Leal
e70d8d4dfd agent: Simplify user message design more (#29326)
Follow-up to https://github.com/zed-industries/zed/pull/29165 where the
user message design is simplified even more. The edit button is not
visible anymore, and you can click on the whole message block to edit a
message.

Release Notes:

- N/A
2025-04-24 11:24:36 -03:00
Marshall Bowers
ea5ce2a1a4 collab: Remove unused RateLimiter (#29343)
This PR removes the `RateLimiter` from the collab codebase, as it is no
longer used.

Release Notes:

- N/A
2025-04-24 14:23:17 +00:00
Peter Tripp
fd8eeb537d Fix ctrl-enter opening inline-assistant in assistant text threads (#29313)
Closes: https://github.com/zed-industries/zed/issues/24501

This has been broken for a while on linux (at least since Feb 8th!) for Assistant1.
It is also broken for Text Threads in Assitant2 (on macos and linux).

This should fix both.

Potentially related:
- https://github.com/zed-industries/zed/pull/29107

Release Notes:

- Fix for `ctrl-enter` shortcut in Assistant text threads incorrectly
opening inline assist instead of triggering Send.

Co-authored-by: Conrad Irwin <conrad@zed.dev>
2025-04-24 09:17:35 -04:00
Marshall Bowers
92f21ee39d collab: Return current plan based on subscription status (#29341)
This PR makes collab return the current plan based on subscription
status instead of based on the staff bit.

Release Notes:

- N/A
2025-04-24 13:04:25 +00:00
102 changed files with 3714 additions and 4329 deletions

View File

@@ -2,7 +2,7 @@ name: Run Agent Eval
on:
schedule:
- cron: "0 * * * *"
- cron: "0 0 * * *"
pull_request:
branches:

78
Cargo.lock generated
View File

@@ -61,7 +61,6 @@ dependencies = [
"buffer_diff",
"chrono",
"client",
"clock",
"collections",
"command_palette_hooks",
"component",
@@ -96,12 +95,13 @@ dependencies = [
"paths",
"picker",
"project",
"prompt_library",
"prompt_store",
"proto",
"rand 0.8.5",
"ref-cast",
"release_channel",
"rope",
"rules_library",
"schemars",
"serde",
"serde_json",
@@ -498,11 +498,11 @@ dependencies = [
"parking_lot",
"pretty_assertions",
"project",
"prompt_library",
"prompt_store",
"proto",
"rand 0.8.5",
"rope",
"rules_library",
"schemars",
"search",
"serde",
@@ -11083,32 +11083,6 @@ dependencies = [
"thiserror 2.0.12",
]
[[package]]
name = "prompt_library"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"editor",
"gpui",
"language",
"language_model",
"log",
"menu",
"picker",
"prompt_store",
"release_channel",
"rope",
"serde",
"settings",
"theme",
"ui",
"util",
"workspace",
"workspace-hack",
"zed_actions",
]
[[package]]
name = "prompt_store"
version = "0.1.0"
@@ -11742,6 +11716,26 @@ dependencies = [
"thiserror 2.0.12",
]
[[package]]
name = "ref-cast"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
dependencies = [
"ref-cast-impl",
]
[[package]]
name = "ref-cast-impl"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "refineable"
version = "0.1.0"
@@ -12271,6 +12265,32 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rules_library"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"editor",
"gpui",
"language",
"language_model",
"log",
"menu",
"picker",
"prompt_store",
"release_channel",
"rope",
"serde",
"settings",
"theme",
"ui",
"util",
"workspace",
"workspace-hack",
"zed_actions",
]
[[package]]
name = "runtimelib"
version = "0.25.0"

View File

@@ -109,7 +109,7 @@ members = [
"crates/project",
"crates/project_panel",
"crates/project_symbols",
"crates/prompt_library",
"crates/rules_library",
"crates/prompt_store",
"crates/proto",
"crates/recent_projects",
@@ -319,7 +319,7 @@ prettier = { path = "crates/prettier" }
project = { path = "crates/project" }
project_panel = { path = "crates/project_panel" }
project_symbols = { path = "crates/project_symbols" }
prompt_library = { path = "crates/prompt_library" }
rules_library = { path = "crates/rules_library" }
prompt_store = { path = "crates/prompt_store" }
proto = { path = "crates/proto" }
recent_projects = { path = "crates/recent_projects" }
@@ -500,6 +500,7 @@ prost-types = "0.9"
pulldown-cmark = { version = "0.12.0", default-features = false }
quote = "1.0.9"
rand = "0.8.5"
ref-cast = "1.0.24"
rayon = "1.8"
regex = "1.5"
repair_json = "0.1.0"

View File

@@ -212,7 +212,7 @@
"ctrl-shift-g": "search::SelectPreviousMatch",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-k h": "assistant::DeployHistory",
"ctrl-k l": "assistant::OpenPromptLibrary",
"ctrl-k l": "assistant::OpenRulesLibrary",
"new": "assistant::NewChat",
"ctrl-t": "assistant::NewChat",
"ctrl-n": "assistant::NewChat"
@@ -241,7 +241,7 @@
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-alt-p": "assistant::OpenPromptLibrary",
"ctrl-alt-p": "assistant::OpenRulesLibrary",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
@@ -308,9 +308,9 @@
{
"context": "PromptLibrary",
"bindings": {
"new": "prompt_library::NewPrompt",
"ctrl-n": "prompt_library::NewPrompt",
"ctrl-shift-s": "prompt_library::ToggleDefaultPrompt"
"new": "rules_library::NewRule",
"ctrl-n": "rules_library::NewRule",
"ctrl-shift-s": "rules_library::ToggleDefaultRule"
}
},
{
@@ -675,7 +675,7 @@
}
},
{
"context": "Editor && mode == full",
"context": "!ContextEditor > Editor && mode == full",
"bindings": {
"alt-enter": "editor::OpenExcerpts",
"shift-enter": "editor::ExpandExcerpts",

View File

@@ -5,8 +5,8 @@
"context": "PromptLibrary",
"use_key_equivalents": true,
"bindings": {
"cmd-n": "prompt_library::NewPrompt",
"cmd-shift-s": "prompt_library::ToggleDefaultPrompt",
"cmd-n": "rules_library::NewRule",
"cmd-shift-s": "rules_library::ToggleDefaultRule",
"cmd-w": "workspace::CloseWindow"
}
},
@@ -257,7 +257,7 @@
"cmd-shift-g": "search::SelectPreviousMatch",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-k h": "assistant::DeployHistory",
"cmd-k l": "assistant::OpenPromptLibrary",
"cmd-k l": "assistant::OpenRulesLibrary",
"cmd-t": "assistant::NewChat",
"cmd-n": "assistant::NewChat"
}
@@ -286,7 +286,7 @@
"cmd-alt-n": "agent::NewTextThread",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-alt-p": "assistant::OpenPromptLibrary",
"cmd-alt-p": "assistant::OpenRulesLibrary",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
@@ -738,7 +738,7 @@
}
},
{
"context": "Editor && mode == full",
"context": "!ContextEditor > Editor && mode == full",
"use_key_equivalents": true,
"bindings": {
"alt-enter": "editor::OpenExcerpts",

View File

@@ -1,2 +1,7 @@
allow-private-module-inception = true
avoid-breaking-exported-api = false
ignore-interior-mutability = [
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
# and Hash impls do not use fields with interior mutability.
"agent::context::AgentContextKey"
]

View File

@@ -28,7 +28,6 @@ async-watch.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@@ -62,9 +61,10 @@ parking_lot.workspace = true
paths.workspace = true
picker.workspace = true
project.workspace = true
prompt_library.workspace = true
rules_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
ref-cast.workspace = true
release_channel.workspace = true
rope.workspace = true
schemars.workspace = true

View File

@@ -1,4 +1,4 @@
use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
use crate::context::{AgentContext, RULES_ICON};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
@@ -25,8 +25,8 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role,
StopReason,
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
RequestUsage, Role, StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@@ -43,16 +43,14 @@ use ui::{
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownString;
use workspace::{OpenOptions, Workspace};
use zed_actions::assistant::OpenPromptLibrary;
use crate::context_store::ContextStore;
use zed_actions::assistant::OpenRulesLibrary;
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<Thread>,
context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
@@ -695,7 +693,7 @@ fn open_markdown_link(
}),
Some(MentionLink::Fetch(url)) => cx.open_url(&url),
Some(MentionLink::Rules(prompt_id)) => window.dispatch_action(
Box::new(OpenPromptLibrary {
Box::new(OpenRulesLibrary {
prompt_to_select: Some(prompt_id.0),
}),
cx,
@@ -716,7 +714,6 @@ impl ActiveThread {
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
language_registry: Arc<LanguageRegistry>,
context_store: Entity<ContextStore>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
@@ -739,7 +736,6 @@ impl ActiveThread {
language_registry,
thread_store,
thread: thread.clone(),
context_store,
workspace,
save_thread_task: None,
messages: Vec::new(),
@@ -769,7 +765,7 @@ impl ActiveThread {
this.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
tool_use.status.text(),
cx,
);
@@ -779,10 +775,6 @@ impl ActiveThread {
this
}
pub fn context_store(&self) -> &Entity<ContextStore> {
&self.context_store
}
pub fn thread(&self) -> &Entity<Thread> {
&self.thread
}
@@ -870,7 +862,7 @@ impl ActiveThread {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_label: impl Into<SharedString>,
tool_input: &serde_json::Value,
tool_input: &str,
tool_output: SharedString,
cx: &mut Context<Self>,
) {
@@ -893,11 +885,10 @@ impl ActiveThread {
this.replace(tool_label, cx);
});
rendered.input.update(cx, |this, cx| {
let input = format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
this.replace(
MarkdownString::code_block("json", tool_input).to_string(),
cx,
);
this.replace(input, cx);
});
rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx);
@@ -988,7 +979,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
"".into(),
cx,
);
@@ -1002,7 +993,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text.clone(),
input,
&serde_json::to_string_pretty(&input).unwrap_or_default(),
"".into(),
cx,
);
@@ -1014,7 +1005,7 @@ impl ActiveThread {
self.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
&tool_use.input,
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread
.read(cx)
.output_for_tool(&tool_use.id)
@@ -1026,6 +1017,23 @@ impl ActiveThread {
}
ThreadEvent::CheckpointChanged => cx.notify(),
ThreadEvent::ReceivedTextChunk => {}
ThreadEvent::InvalidToolInput {
tool_use_id,
ui_text,
invalid_input_json,
} => {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text,
invalid_input_json,
self.thread
.read(cx)
.output_for_tool(tool_use_id)
.map(|output| output.clone().into())
.unwrap_or("".into()),
cx,
);
}
}
}
@@ -1256,26 +1264,36 @@ impl ActiveThread {
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = thread.read(cx).context_for_message(message_id);
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let Some(message) = thread.read(cx).message(message_id) else {
log::error!("Message that was being edited no longer exists");
return None;
};
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
if message_text.is_empty() && message.loaded_context.is_empty() {
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
message
.loaded_context
.add_to_request_message(&mut request_message);
if !message_text.is_empty() {
request_message
.content
.push(MessageContent::Text(message_text));
}
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,
@@ -1470,13 +1488,21 @@ impl ActiveThread {
return Empty.into_any();
};
let context_store = self.context_store.clone();
let workspace = self.workspace.clone();
let thread = self.thread.read(cx);
let prompt_store = self.thread_store.read(cx).prompt_store().as_ref();
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id).collect::<Vec<_>>();
let added_context = if let Some(workspace) = workspace.upgrade() {
let project = workspace.read(cx).project().read(cx);
thread
.context_for_message(message_id)
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
.collect::<Vec<_>>()
} else {
return Empty.into_any();
};
let tool_uses = thread.tool_uses_for_message(message_id, cx);
let has_tool_uses = !tool_uses.is_empty();
@@ -1485,41 +1511,13 @@ impl ActiveThread {
let is_first_message = ix == 0;
let is_last_message = ix == self.messages.len() - 1;
let show_feedback = (!is_generating && is_last_message && message.role != Role::User)
|| self.messages.get(ix + 1).map_or(false, |next_id| {
self.thread
.read(cx)
.message(*next_id)
.map_or(false, |next_message| {
next_message.role == Role::User
&& thread.tool_uses_for_message(*next_id, cx).is_empty()
&& thread.tool_results_for_message(*next_id).is_empty()
})
});
let show_feedback = thread.is_turn_end(ix);
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
// Don't render user messages that are just there for returning tool results.
if message.role == Role::User && thread.message_has_tool_results(message_id) {
if let Some(generating_label) = generating_label {
return h_flex()
.w_full()
.h_10()
.py_1p5()
.pl_4()
.pb_3()
.child(generating_label)
.into_any_element();
}
return Empty.into_any();
}
let allow_editing_message = message.role == Role::User;
let edit_message_editor = self
.editing_message
.as_ref()
@@ -1652,90 +1650,78 @@ impl ActiveThread {
};
let message_is_empty = message.should_display_content();
let has_content = !message_is_empty || !context.is_empty();
let has_content = !message_is_empty || !added_context.is_empty();
let message_content =
has_content.then(|| {
v_flex()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let message_content = has_content.then(|| {
v_flex()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small.rems(cx);
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
context.into_iter().map(|context| {
let context_id = context.id();
ContextPill::added(
AddedContext::new(context, cx),
false,
false,
None,
)
.on_click(Rc::new(cx.listener({
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
background: colors.editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
))
.into_any()
} else {
div()
.min_h_6()
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
))
.into_any()
},
)
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
let context = added_context.context.clone();
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
cx.listener({
let workspace = workspace.clone();
let context_store = context_store.clone();
move |_, _, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_context(
context_id,
context_store.clone(),
workspace,
window,
cx,
);
open_context(&context, workspace, window, cx);
cx.notify();
}
}
})))
}),
))
})
});
}),
))
}),
))
})
});
let styled_message = match message.role {
Role::User => v_flex()
@@ -1752,93 +1738,70 @@ impl ActiveThread {
.pb_4()
.child(
v_flex()
.id(("user-message", ix))
.bg(editor_bg_color)
.rounded_lg()
.shadow_md()
.border_1()
.border_color(colors.border)
.shadow_md()
.child(div().p_2().children(message_content))
.child(
h_flex()
.p_1()
.border_t_1()
.border_color(colors.border_variant)
.justify_end()
.child(
h_flex()
.gap_1()
.when_some(
edit_message_editor.clone(),
|this, edit_message_editor| {
let focus_handle =
edit_message_editor.focus_handle(cx);
this.child(
Button::new("cancel-edit-message", "Cancel")
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&menu::Cancel,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(
cx.listener(Self::handle_cancel_click),
),
.hover(|hover| hover.border_color(colors.text_accent.opacity(0.5)))
.cursor_pointer()
.child(div().p_2().pt_2p5().children(message_content))
.when_some(edit_message_editor.clone(), |this, edit_editor| {
let focus_handle = edit_editor.focus_handle(cx);
this.child(
h_flex()
.p_1()
.border_t_1()
.border_color(colors.border_variant)
.gap_1()
.justify_end()
.child(
Button::new("cancel-edit-message", "Cancel")
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&menu::Cancel,
&focus_handle,
window,
cx,
)
.child(
Button::new(
"confirm-edit-message",
"Regenerate",
)
.disabled(
edit_message_editor.read(cx).is_empty(cx),
)
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&menu::Confirm,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(
cx.listener(Self::handle_regenerate_click),
),
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(cx.listener(Self::handle_cancel_click)),
)
.child(
Button::new("confirm-edit-message", "Regenerate")
.disabled(edit_editor.read(cx).is_empty(cx))
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&menu::Confirm,
&focus_handle,
window,
cx,
)
},
)
.when(
edit_message_editor.is_none() && allow_editing_message,
|this| {
this.child(
Button::new("edit-message", "Edit Message")
.label_size(LabelSize::Small)
.icon(IconName::Pencil)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.on_click(cx.listener({
let message_segments =
message.segments.clone();
move |this, _, window, cx| {
this.start_editing_message(
message_id,
&message_segments,
window,
cx,
);
}
})),
)
},
),
),
),
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(cx.listener(Self::handle_regenerate_click)),
),
)
})
.when(edit_message_editor.is_none(), |this| {
this.tooltip(Tooltip::text("Click To Edit"))
})
.on_click(cx.listener({
let message_segments = message.segments.clone();
move |this, _, window, cx| {
this.start_editing_message(
message_id,
&message_segments,
window,
cx,
);
}
})),
),
Role::Assistant => v_flex()
.id(("message-container", ix))
@@ -2995,11 +2958,11 @@ impl ActiveThread {
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
// TODO: Figure out a way to pass focus handle here so we can display the `OpenRulesLibrary` keybinding
.tooltip(Tooltip::text("View User Rules"))
.on_click(move |_event, window, cx| {
window.dispatch_action(
Box::new(OpenPromptLibrary {
Box::new(OpenRulesLibrary {
prompt_to_select: first_user_rules_id,
}),
cx,
@@ -3207,20 +3170,14 @@ impl Render for ActiveThread {
}
pub(crate) fn open_context(
id: ContextId,
context_store: Entity<ContextStore>,
context: &AgentContext,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut App,
) {
let Some(context) = context_store.read(cx).context_for_id(id) else {
return;
};
match context {
AssistantContext::File(file_context) => {
if let Some(project_path) = file_context.context_buffer.buffer.read(cx).project_path(cx)
{
AgentContext::File(file_context) => {
if let Some(project_path) = file_context.project_path(cx) {
workspace.update(cx, |workspace, cx| {
workspace
.open_path(project_path, None, true, window, cx)
@@ -3228,7 +3185,8 @@ pub(crate) fn open_context(
});
}
}
AssistantContext::Directory(directory_context) => {
AgentContext::Directory(directory_context) => {
let entry_id = directory_context.entry_id;
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |_project, cx| {
@@ -3236,61 +3194,51 @@ pub(crate) fn open_context(
})
})
}
AssistantContext::Symbol(symbol_context) => {
if let Some(project_path) = symbol_context
.context_symbol
.buffer
.read(cx)
.project_path(cx)
{
let snapshot = symbol_context.context_symbol.buffer.read(cx).snapshot();
let target_position = symbol_context
.context_symbol
.id
.range
.start
.to_point(&snapshot);
AgentContext::Symbol(symbol_context) => {
let buffer = symbol_context.buffer.read(cx);
if let Some(project_path) = buffer.project_path(cx) {
let snapshot = buffer.snapshot();
let target_position = symbol_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
AssistantContext::Selection(selection_context) => {
if let Some(project_path) = selection_context
.context_buffer
.buffer
.read(cx)
.project_path(cx)
{
let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot();
AgentContext::Selection(selection_context) => {
let buffer = selection_context.buffer.read(cx);
if let Some(project_path) = buffer.project_path(cx) {
let snapshot = buffer.snapshot();
let target_position = selection_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
}
}
AssistantContext::FetchedUrl(fetched_url_context) => {
AgentContext::FetchedUrl(fetched_url_context) => {
cx.open_url(&fetched_url_context.url);
}
AssistantContext::Thread(thread_context) => {
let thread_id = thread_context.thread.read(cx).id().clone();
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
panel.update(cx, |panel, cx| {
panel
.open_thread(&thread_id, window, cx)
.detach_and_log_err(cx)
});
}
})
}
AssistantContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenPromptLibrary {
AgentContext::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
panel.update(cx, |panel, cx| {
let thread_id = thread_context.thread.read(cx).id().clone();
panel
.open_thread(&thread_id, window, cx)
.detach_and_log_err(cx)
});
}
}),
AgentContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenRulesLibrary {
prompt_to_select: Some(rules_context.prompt_id.0),
}),
cx,
),
AssistantContext::Image(_) => {}
AgentContext::Image(_) => {}
}
}

View File

@@ -962,11 +962,13 @@ mod tests {
})
.unwrap();
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View File

@@ -39,6 +39,7 @@ use thread::ThreadId;
pub use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;

View File

@@ -24,9 +24,9 @@ use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use proto::Plan;
use rules_library::{RulesLibrary, open_rules_library};
use settings::{Settings, update_settings_file};
use time::UtcOffset;
use ui::{
@@ -36,7 +36,7 @@ use util::ResultExt as _;
use workspace::Workspace;
use workspace::dock::{DockPosition, Panel, PanelEvent};
use zed_actions::agent::OpenConfiguration;
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
@@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
}
})
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(action, window, cx)
panel.deploy_rules_library(action, window, cx)
});
}
})
@@ -189,6 +189,7 @@ pub struct AssistantPanel {
message_editor: Entity<MessageEditor>,
_active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>,
local_timezone: UtcOffset,
@@ -205,14 +206,25 @@ impl AssistantPanel {
pub fn load(
workspace: WeakEntity<Workspace>,
prompt_builder: Arc<PromptBuilder>,
cx: AsyncWindowContext,
mut cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
cx.spawn(async move |cx| {
let prompt_store = match prompt_store {
Ok(prompt_store) => prompt_store.await.ok(),
Err(_) => None,
};
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
ThreadStore::load(
project,
tools.clone(),
prompt_store.clone(),
prompt_builder.clone(),
cx,
)
})?
.await?;
@@ -230,7 +242,16 @@ impl AssistantPanel {
.await?;
workspace.update_in(cx, |workspace, window, cx| {
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
cx.new(|cx| {
Self::new(
workspace,
thread_store,
context_store,
prompt_store,
window,
cx,
)
})
})
})
}
@@ -239,6 +260,7 @@ impl AssistantPanel {
workspace: &Workspace,
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -262,6 +284,7 @@ impl AssistantPanel {
fs.clone(),
workspace.clone(),
message_editor_context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
thread.clone(),
window,
@@ -293,7 +316,6 @@ impl AssistantPanel {
thread.clone(),
thread_store.clone(),
language_registry.clone(),
message_editor_context_store.clone(),
workspace.clone(),
window,
cx,
@@ -322,6 +344,7 @@ impl AssistantPanel {
message_editor_subscription,
],
context_store,
prompt_store,
configuration: None,
configuration_subscription: None,
local_timezone: UtcOffset::from_whole_seconds(
@@ -355,6 +378,10 @@ impl AssistantPanel {
self.local_timezone
}
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
&self.thread_store
}
@@ -411,7 +438,6 @@ impl AssistantPanel {
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
message_editor_context_store.clone(),
self.workspace.clone(),
window,
cx,
@@ -430,6 +456,7 @@ impl AssistantPanel {
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
window,
@@ -484,13 +511,13 @@ impl AssistantPanel {
context_editor.focus_handle(cx).focus(window);
}
fn deploy_prompt_library(
fn deploy_rules_library(
&mut self,
action: &OpenPromptLibrary,
action: &OpenRulesLibrary,
_window: &mut Window,
cx: &mut Context<Self>,
) {
open_prompt_library(
open_rules_library(
self.language_registry.clone(),
Box::new(PromptLibraryInlineAssist::new(self.workspace.clone())),
Arc::new(|| {
@@ -500,9 +527,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action
.prompt_to_select
.map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@@ -598,7 +625,6 @@ impl AssistantPanel {
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
message_editor_context_store.clone(),
this.workspace.clone(),
window,
cx,
@@ -617,6 +643,7 @@ impl AssistantPanel {
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
this.prompt_store.clone(),
this.thread_store.downgrade(),
thread,
window,
@@ -1120,7 +1147,7 @@ impl AssistantPanel {
"New Text Thread",
NewTextThread.boxed_clone(),
)
.action("Prompt Library", Box::new(OpenPromptLibrary::default()))
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenConfiguration))
.separator()
.header("MCPs")
@@ -1833,7 +1860,7 @@ impl Render for AssistantPanel {
this.open_configuration(window, cx);
}))
.on_action(cx.listener(Self::open_active_thread_as_markdown))
.on_action(cx.listener(Self::deploy_prompt_library))
.on_action(cx.listener(Self::deploy_rules_library))
.on_action(cx.listener(Self::open_agent_diff))
.on_action(cx.listener(Self::go_back))
.child(self.render_toolbar(window, cx))
@@ -1860,13 +1887,13 @@ impl PromptLibraryInlineAssist {
}
}
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
fn assist(
&self,
prompt_editor: &Entity<Editor>,
_initial_prompt: Option<String>,
window: &mut Window,
cx: &mut Context<PromptLibrary>,
cx: &mut Context<RulesLibrary>,
) {
InlineAssistant::update_global(cx, |assistant, cx| {
let Some(project) = self
@@ -1876,11 +1903,14 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
else {
return;
};
let prompt_store = None;
let thread_store = None;
assistant.assist(
&prompt_editor,
self.workspace.clone(),
project,
None,
prompt_store,
thread_store,
window,
cx,
)
@@ -1959,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
// being updated.
cx.defer_in(window, move |panel, window, cx| {
if panel.has_active_thread() {
panel.thread.update(cx, |thread, cx| {
thread.context_store().update(cx, |store, cx| {
panel.message_editor.update(cx, |message_editor, cx| {
message_editor.context_store().update(cx, |store, cx| {
let buffer = buffer.read(cx);
let selection_ranges = selection_ranges
.into_iter()
@@ -1977,9 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
.collect::<Vec<_>>();
for (buffer, range) in selection_ranges {
store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
store.add_selection(buffer, range, cx);
}
})
})

View File

@@ -1,12 +1,14 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::context::ContextLoadResult;
use crate::inline_prompt_editor::CodegenStatus;
use crate::{context::load_context, context_store::ContextStore};
use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
@@ -14,7 +16,9 @@ use language_model::{
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use rope::Rope;
use smol::future::FutureExt;
use std::{
@@ -39,6 +43,8 @@ pub struct BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
pub is_insertion: bool,
@@ -50,6 +56,8 @@ impl BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -60,6 +68,8 @@ impl BufferCodegen {
range.clone(),
false,
Some(context_store.clone()),
project.clone(),
prompt_store.clone(),
Some(telemetry.clone()),
builder.clone(),
cx,
@@ -75,6 +85,8 @@ impl BufferCodegen {
range,
initial_transaction_id,
context_store,
project,
prompt_store,
telemetry,
builder,
};
@@ -153,6 +165,8 @@ impl BufferCodegen {
self.range.clone(),
false,
Some(self.context_store.clone()),
self.project.clone(),
self.prompt_store.clone(),
Some(self.telemetry.clone()),
self.builder.clone(),
cx,
@@ -229,13 +243,14 @@ pub struct CodegenAlternative {
generation: Task<()>,
diff: Diff,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
active: bool,
edits: Vec<(Range<Anchor>, String)>,
line_operations: Vec<LineOperation>,
request: Option<LanguageModelRequest>,
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
@@ -249,6 +264,8 @@ impl CodegenAlternative {
range: Range<Anchor>,
active: bool,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -290,6 +307,8 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
context_store,
project,
prompt_store,
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
@@ -297,7 +316,6 @@ impl CodegenAlternative {
edits: Vec::new(),
line_operations: Vec::new(),
range,
request: None,
elapsed_time: None,
completion: None,
}
@@ -366,16 +384,18 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(user_prompt, cx)?;
self.request = Some(request.clone());
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())
}
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
fn build_request(
&self,
user_prompt: String,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -408,30 +428,44 @@ impl CodegenAlternative {
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
let context_task = self.context_store.as_ref().map(|context_store| {
if let Some(project) = self.project.upgrade() {
let context = context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
load_context(context, &project, &self.prompt_store, cx)
} else {
Task::ready(ContextLoadResult::default())
}
});
if let Some(context_store) = &self.context_store {
attach_context_to_message(
&mut request_message,
context_store.read(cx).context().iter(),
cx,
);
}
Ok(cx.spawn(async move |_cx| {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
request_message.content.push(prompt.into());
if let Some(context_task) = context_task {
context_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
}
Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
})
request_message.content.push(prompt.into());
LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
}
}))
}
pub fn handle_stream(
@@ -508,7 +542,9 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks = StripInvalidSpans::new(
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -1034,6 +1070,7 @@ impl Diff {
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use futures::{
Stream,
stream::{self},
@@ -1076,12 +1113,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1140,12 +1181,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1207,12 +1252,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1274,12 +1323,16 @@ mod tests {
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1329,12 +1382,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
false,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,

File diff suppressed because it is too large Load Diff

View File

@@ -10,8 +10,11 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Result, anyhow};
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use fetch_context_picker::FetchContextPicker;
use file_context_picker::FileContextPicker;
use file_context_picker::render_file_context_entry;
use gpui::{
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
@@ -20,10 +23,10 @@ use gpui::{
use language::Buffer;
use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath};
use prompt_store::UserPromptId;
use rules_context_picker::RulesContextEntry;
use prompt_store::{PromptStore, UserPromptId};
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
use symbol_context_picker::SymbolContextPicker;
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
};
@@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel;
use crate::context::RULES_ICON;
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::rules_context_picker::RulesContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@@ -166,6 +164,7 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@@ -193,6 +192,13 @@ impl ContextPicker {
)
.collect::<Vec<Subscription>>();
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten()
});
ContextPicker {
mode: ContextPickerState::Default(ContextMenu::build(
window,
@@ -202,6 +208,7 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
prompt_store,
_subscriptions: subscriptions,
}
}
@@ -226,7 +233,12 @@ impl ContextPicker {
.workspace
.upgrade()
.map(|workspace| {
available_context_picker_entries(&self.thread_store, &workspace, cx)
available_context_picker_entries(
&self.prompt_store,
&self.thread_store,
&workspace,
cx,
)
})
.unwrap_or_default();
@@ -304,10 +316,10 @@ impl ContextPicker {
}));
}
ContextPickerMode::Rules => {
if let Some(thread_store) = self.thread_store.as_ref() {
if let Some(prompt_store) = self.prompt_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
thread_store.clone(),
prompt_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
@@ -526,6 +538,7 @@ enum RecentEntry {
}
fn available_context_picker_entries(
prompt_store: &Option<Entity<PromptStore>>,
thread_store: &Option<WeakEntity<ThreadStore>>,
workspace: &Entity<Workspace>,
cx: &mut App,
@@ -550,6 +563,9 @@ fn available_context_picker_entries(
if thread_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
}
if prompt_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
}
@@ -585,22 +601,21 @@ fn recent_context_picker_entries(
}),
);
let mut current_threads = context_store.read(cx).thread_ids();
let current_threads = context_store.read(cx).thread_ids();
if let Some(active_thread) = workspace
let active_thread_id = workspace
.panel::<AssistantPanel>(cx)
.map(|panel| panel.read(cx).active_thread(cx))
{
current_threads.insert(active_thread.read(cx).id().clone());
}
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
recent.extend(
thread_store
.read(cx)
.threads()
.reverse_chronological_threads()
.into_iter()
.filter(|thread| !current_threads.contains(&thread.id))
.filter(|thread| {
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
})
.take(2)
.map(|thread| {
RecentEntry::Thread(ThreadContextEntry {
@@ -622,9 +637,7 @@ fn add_selections_as_context(
let selection_ranges = selection_ranges(workspace, cx);
context_store.update(cx, |context_store, cx| {
for (buffer, range) in selection_ranges {
context_store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
context_store.add_selection(buffer, range, cx);
}
})
}

View File

@@ -15,22 +15,21 @@ use itertools::Itertools;
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
use prompt_store::PromptId;
use prompt_store::PromptStore;
use rope::Point;
use text::{Anchor, OffsetRangeExt, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::file_context_picker::FileMatch;
use super::file_context_picker::{FileMatch, search_files};
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch;
use super::symbol_context_picker::search_symbols;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
@@ -38,8 +37,8 @@ use super::{
};
pub(crate) enum Match {
Symbol(SymbolMatch),
File(FileMatch),
Symbol(SymbolMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Rules(RulesContextEntry),
@@ -69,6 +68,7 @@ fn search(
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
cx: &mut App,
@@ -85,6 +85,7 @@ fn search(
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
@@ -96,6 +97,7 @@ fn search(
.collect()
})
}
Some(ContextPickerMode::Thread) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_threads_task =
@@ -111,6 +113,7 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
@@ -118,10 +121,11 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
if let Some(prompt_store) = prompt_store.as_ref() {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
@@ -133,6 +137,7 @@ fn search(
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
@@ -163,7 +168,7 @@ fn search(
.collect::<Vec<_>>();
matches.extend(
available_context_picker_entries(&thread_store, &workspace, cx)
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
@@ -180,7 +185,8 @@ fn search(
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
let entries =
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
let entry_candidates = entries
.iter()
.enumerate()
@@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider {
move |_, _: &mut Window, cx: &mut App| {
context_store.update(cx, |context_store, cx| {
for (buffer, range) in &selections {
context_store
.add_selection(buffer.clone(), range.clone(), cx)
.detach_and_log_err(cx)
context_store.add_selection(
buffer.clone(),
range.clone(),
cx,
);
}
});
@@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider {
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
@@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
move |cx| {
let prompt_uuid = rules.prompt_id;
let prompt_id = PromptId::User { uuid: prompt_uuid };
let context_store = context_store.clone();
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
log::error!("Can't add user rules as prompt store is missing.");
return;
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return;
};
let Some(title) = metadata.title else {
return;
};
let text_task = prompt_store.load(prompt_id, cx);
cx.spawn(async move |cx| {
let text = text_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(prompt_uuid, title, text, false, cx)
})
})
.detach_and_log_err(cx);
let user_prompt_id = rules.prompt_id;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(user_prompt_id, false, cx);
});
},
)),
}
@@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider {
let url_to_fetch = url_to_fetch.clone();
cx.spawn(async move |cx| {
if context_store.update(cx, |context_store, _| {
context_store.includes_url(&url_to_fetch).is_some()
context_store.includes_url(&url_to_fetch)
})? {
return Ok(());
}
@@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider {
move |cx| {
context_store.update(cx, |context_store, cx| {
let task = if is_directory {
context_store.add_directory(project_path.clone(), false, cx)
Task::ready(context_store.add_directory(&project_path, false, cx))
} else {
context_store.add_file_from_path(project_path.clone(), false, cx)
};
@@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
);
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten()
});
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
prompt_store,
thread_store.clone(),
workspace.clone(),
cx,
@@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
))
}
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
@@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
workspace.clone(),
cx,
),
Match::Thread(ThreadMatch {
thread, is_recent, ..
}) => {
@@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store,
))
}
Match::Rules(user_rules) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Rules(user_rules) => Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
)),
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
@@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
entry,
excerpt_id,

View File

@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store.read(cx).includes_url(&self.url).is_some()
context_store.read(cx).includes_url(&self.url)
});
Some(

View File

@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
.context_store
.update(cx, |context_store, cx| {
if is_directory {
context_store.add_directory(project_path, true, cx)
Task::ready(context_store.add_directory(&project_path, true, cx))
} else {
context_store.add_file_from_path(project_path, true, cx)
context_store.add_file_from_path(project_path.clone(), true, cx)
}
})
.ok()
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
path: path.clone(),
};
if is_directory {
context_store.read(cx).includes_directory(&project_path)
} else {
context_store
.read(cx)
.will_include_file_path(&project_path, cx)
.path_included_in_directory(&project_path, cx)
} else {
context_store.read(cx).file_path_included(&project_path, cx)
}
});
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
})),
)
.when_some(added, |el, added| match added {
FileInclusion::Direct(_) => el.child(
FileInclusion::Direct => el.child(
h_flex()
.w_full()
.justify_end()
@@ -369,9 +369,8 @@ pub fn render_file_context_entry(
)
.child(Label::new("Added").size(LabelSize::Small)),
),
FileInclusion::InDirectory(directory_project_path) => {
// TODO: Consider using worktree full_path to include worktree name.
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
FileInclusion::InDirectory { full_path } => {
let directory_full_path = full_path.to_string_lossy().into_owned();
el.child(
h_flex()
@@ -385,7 +384,7 @@ pub fn render_file_context_entry(
)
.child(Label::new("Included").size(LabelSize::Small)),
)
.tooltip(Tooltip::text(format!("in {directory_path}")))
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
}
})
}

View File

@@ -1,16 +1,15 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use prompt_store::{PromptId, UserPromptId};
use prompt_store::{PromptId, PromptStore, UserPromptId};
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
@@ -18,13 +17,13 @@ pub struct RulesContextPicker {
impl RulesContextPicker {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
@@ -50,7 +49,7 @@ pub struct RulesContextEntry {
}
pub struct RulesContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
@@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate {
impl RulesContextPickerDelegate {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
thread_store,
prompt_store,
context_picker,
context_store,
matches: Vec::new(),
@@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
let search_task = search_rules(
query,
Arc::new(AtomicBool::default()),
&self.prompt_store,
cx,
);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
@@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
return;
};
let Some(thread_store) = self.thread_store.upgrade() else {
return;
};
let prompt_id = entry.prompt_id;
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
});
cx.spawn(async move |this, cx| {
let (metadata, text) = load_rules_task.await?;
let Some(title) = metadata.title else {
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
};
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(prompt_id, title, text, true, cx)
})
.ok();
self.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(entry.prompt_id, true, cx)
})
})
.detach_and_log_err(cx);
.log_err();
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
@@ -179,11 +159,10 @@ pub fn render_thread_context_entry(
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store
let added = context_store.upgrade().map_or(false, |context_store| {
context_store
.read(cx)
.includes_user_rules(&user_rules.prompt_id)
.is_some()
.includes_user_rules(user_rules.prompt_id)
});
h_flex()
@@ -218,12 +197,9 @@ pub fn render_thread_context_entry(
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
prompt_store: &Entity<PromptStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
return Task::ready(vec![]);
};
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task

View File

@@ -10,7 +10,6 @@ use gpui::{
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use project::{DocumentSymbol, Symbol};
use text::OffsetRangeExt;
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
@@ -228,18 +227,16 @@ pub(crate) fn add_symbol(
)
})?;
context_store
.update(cx, move |context_store, cx| {
context_store.add_symbol(
buffer,
name.into(),
range,
enclosing_range,
remove_if_exists,
cx,
)
})?
.await
context_store.update(cx, move |context_store, cx| {
context_store.add_symbol(
buffer,
name.into(),
range,
enclosing_range,
remove_if_exists,
cx,
)
})
})
}
@@ -353,38 +350,13 @@ fn compute_symbol_entries(
context_store: &ContextStore,
cx: &App,
) -> Vec<SymbolEntry> {
let mut symbol_entries = Vec::with_capacity(symbols.len());
for SymbolMatch { symbol, .. } in symbols {
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
let is_included = if let Some(symbols_for_path) = symbols_for_path {
let mut is_included = false;
for included_symbol_id in symbols_for_path {
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
let snapshot = buffer.read(cx).snapshot();
let included_symbol_range =
included_symbol_id.range.to_point_utf16(&snapshot);
if included_symbol_range.start == symbol.range.start.0
&& included_symbol_range.end == symbol.range.end.0
{
is_included = true;
break;
}
}
}
}
is_included
} else {
false
};
symbol_entries.push(SymbolEntry {
symbols
.into_iter()
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
is_included: context_store.includes_symbol(&symbol, cx),
symbol,
is_included,
})
}
symbol_entries
.collect::<Vec<_>>()
}
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {

View File

@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store.read(cx).includes_thread(&thread.id).is_some()
ctx_store.read(cx).includes_thread(&thread.id)
});
h_flex()
@@ -219,7 +219,7 @@ pub(crate) fn search_threads(
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store
.read(cx)
.threads()
.reverse_chronological_threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,

File diff suppressed because it is too large Load Diff

View File

@@ -12,9 +12,9 @@ use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
use workspace::Workspace;
use crate::context::{ContextId, ContextKind};
use crate::context::{AgentContext, ContextKind};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
use crate::thread::Thread;
@@ -32,6 +32,7 @@ pub struct ContextStrip {
focus_handle: FocusHandle,
suggest_context_kind: SuggestContextKind,
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
_subscriptions: Vec<Subscription>,
focused_index: Option<usize>,
children_bounds: Option<Vec<Bounds<Pixels>>>,
@@ -73,12 +74,31 @@ impl ContextStrip {
focus_handle,
suggest_context_kind,
workspace,
thread_store,
_subscriptions: subscriptions,
focused_index: None,
children_bounds: None,
}
}
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
if let Some(workspace) = self.workspace.upgrade() {
let project = workspace.read(cx).project().read(cx);
let prompt_store = self
.thread_store
.as_ref()
.and_then(|thread_store| thread_store.upgrade())
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
self.context_store
.read(cx)
.context()
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
.collect::<Vec<_>>()
} else {
Vec::new()
}
}
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
match self.suggest_context_kind {
SuggestContextKind::File => self.suggested_file(cx),
@@ -93,22 +113,19 @@ impl ContextStrip {
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_entity.read(cx);
let project_path = active_buffer.project_path(cx)?;
if self
.context_store
.read(cx)
.will_include_buffer(active_buffer.remote_id(), &project_path)
.file_path_included(&project_path, cx)
.is_some()
{
return None;
}
let file_name = active_buffer.file()?.file_name(cx);
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
Some(SuggestedContext::File {
name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(),
@@ -135,7 +152,6 @@ impl ContextStrip {
.context_store
.read(cx)
.includes_thread(active_thread.id())
.is_some()
{
return None;
}
@@ -272,12 +288,12 @@ impl ContextStrip {
best.map(|(index, _, _)| index)
}
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
crate::active_thread::open_context(context, workspace, window, cx);
}
fn remove_focused_context(
@@ -287,17 +303,17 @@ impl ContextStrip {
cx: &mut Context<Self>,
) {
if let Some(index) = self.focused_index {
let mut is_empty = false;
let added_contexts = self.added_contexts(cx);
let Some(context) = added_contexts.get(index) else {
return;
};
self.context_store.update(cx, |this, cx| {
if let Some(item) = this.context().get(index) {
this.remove_context(item.id(), cx);
}
is_empty = this.context().is_empty();
this.remove_context(&context.context, cx);
});
if is_empty {
let is_now_empty = added_contexts.len() == 1;
if is_now_empty {
cx.emit(ContextStripEvent::BlurredEmpty);
} else {
self.focused_index = Some(index.saturating_sub(1));
@@ -306,49 +322,28 @@ impl ContextStrip {
}
}
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
// We only suggest one item after the actual context
self.focused_index == Some(context.len())
self.focused_index == Some(added_contexts.len())
}
fn accept_suggested_context(
&mut self,
_: &AcceptSuggestedContext,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(suggested) = self.suggested_context(cx) {
let context_store = self.context_store.read(cx);
if self.is_suggested_focused(context_store.context()) {
self.add_suggested_context(&suggested, window, cx);
if self.is_suggested_focused(&self.added_contexts(cx)) {
self.add_suggested_context(&suggested, cx);
}
}
}
fn add_suggested_context(
&mut self,
suggested: &SuggestedContext,
window: &mut Window,
cx: &mut Context<Self>,
) {
let task = self.context_store.update(cx, |context_store, cx| {
context_store.accept_suggested_context(&suggested, cx)
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
self.context_store.update(cx, |context_store, cx| {
context_store.add_suggested_context(&suggested, cx)
});
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
None => {}
Some(()) => {
if let Some(this) = this.upgrade() {
this.update(cx, |_, cx| cx.notify())?;
}
}
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
cx.notify();
}
}
@@ -361,17 +356,10 @@ impl Focusable for ContextStrip {
impl Render for ContextStrip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let context_store = self.context_store.read(cx);
let context = context_store.context();
let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone();
let suggested_context = self.suggested_context(cx);
let added_contexts = context
.iter()
.map(|c| AddedContext::new(c, cx))
.collect::<Vec<_>>();
let added_contexts = self.added_contexts(cx);
let dupe_names = added_contexts
.iter()
.map(|c| c.name.clone())
@@ -380,6 +368,14 @@ impl Render for ContextStrip {
.filter(|(a, b)| a == b)
.map(|(a, _)| a)
.collect::<HashSet<SharedString>>();
let no_added_context = added_contexts.is_empty();
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
(
suggested_context,
self.is_suggested_focused(&added_contexts),
)
});
h_flex()
.flex_wrap()
@@ -436,7 +432,7 @@ impl Render for ContextStrip {
})
.with_handle(self.context_picker_menu_handle.clone()),
)
.when(context.is_empty() && suggested_context.is_none(), {
.when(no_added_context && suggested_context.is_none(), {
|parent| {
parent.child(
h_flex()
@@ -466,16 +462,17 @@ impl Render for ContextStrip {
.enumerate()
.map(|(i, added_context)| {
let name = added_context.name.clone();
let id = added_context.id;
let context = added_context.context.clone();
ContextPill::added(
added_context,
dupe_names.contains(&name),
self.focused_index == Some(i),
Some({
let context = context.clone();
let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, _window, cx| {
context_store.update(cx, |this, cx| {
this.remove_context(id, cx);
this.remove_context(&context, cx);
});
cx.notify();
}))
@@ -484,7 +481,7 @@ impl Render for ContextStrip {
.on_click({
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
this.open_context(id, window, cx);
this.open_context(&context, window, cx);
} else {
this.focused_index = Some(i);
}
@@ -493,22 +490,22 @@ impl Render for ContextStrip {
})
}),
)
.when_some(suggested_context, |el, suggested| {
.when_some(suggested_context, |el, (suggested, focused)| {
el.child(
ContextPill::suggested(
suggested.name().clone(),
suggested.icon_path(),
suggested.kind(),
self.is_suggested_focused(&context),
focused,
)
.on_click(Rc::new(cx.listener(
move |this, _event, window, cx| {
this.add_suggested_context(&suggested, window, cx);
move |this, _event, _window, cx| {
this.add_suggested_context(&suggested, cx);
},
))),
)
})
.when(!context.is_empty(), {
.when(!no_added_context, {
move |parent| {
parent.child(
IconButton::new("remove-all-context", IconName::Eraser)
@@ -534,6 +531,7 @@ impl Render for ContextStrip {
)
}
})
.into_any()
}
}

View File

@@ -51,7 +51,10 @@ impl HistoryStore {
return history_entries;
}
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
for thread in self
.thread_store
.update(cx, |this, _cx| this.reverse_chronological_threads())
{
history_entries.push(HistoryEntry::Thread(thread));
}

View File

@@ -32,6 +32,7 @@ use project::LspAction;
use project::Project;
use project::{CodeAction, ProjectTransaction};
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use settings::{Settings, SettingsStore};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
@@ -245,9 +246,13 @@ impl InlineAssistant {
.map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
let assistant_panel = workspace
.panel::<AssistantPanel>(cx)
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
.map(|assistant_panel| assistant_panel.read(cx));
let prompt_store = assistant_panel
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
let thread_store =
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
let handle_assist =
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
@@ -257,6 +262,7 @@ impl InlineAssistant {
&active_editor,
cx.entity().downgrade(),
workspace.project().downgrade(),
prompt_store,
thread_store,
window,
cx,
@@ -269,6 +275,7 @@ impl InlineAssistant {
&active_terminal,
cx.entity().downgrade(),
workspace.project().downgrade(),
prompt_store,
thread_store,
window,
cx,
@@ -323,6 +330,7 @@ impl InlineAssistant {
editor: &Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -437,6 +445,8 @@ impl InlineAssistant {
range.clone(),
None,
context_store.clone(),
project.clone(),
prompt_store.clone(),
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@@ -525,6 +535,7 @@ impl InlineAssistant {
initial_transaction_id: Option<TransactionId>,
focus: bool,
workspace: Entity<Workspace>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -543,7 +554,7 @@ impl InlineAssistant {
}
let project = workspace.read(cx).project().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
let codegen = cx.new(|cx| {
BufferCodegen::new(
@@ -551,6 +562,8 @@ impl InlineAssistant {
range.clone(),
initial_transaction_id,
context_store.clone(),
project,
prompt_store,
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
let thread_store = self.thread_store.clone();
let prompt_store = PromptStore::global(cx);
window.spawn(cx, async move |cx| {
let workspace = workspace.upgrade().context("workspace was released")?;
let editor = editor.upgrade().context("editor was released")?;
@@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
})?
.context("invalid range")?;
let prompt_store = prompt_store.await.ok();
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
let assist_id = assistant.suggest_assist(
&editor,
@@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
None,
true,
workspace,
prompt_store,
thread_store,
window,
cx,

View File

@@ -13,7 +13,7 @@ use editor::{
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
actions::{MoveDown, MoveUp},
};
use feature_flags::{FeatureFlagAppExt as _, ZedPro};
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
use fs::Fs;
use gpui::{
AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
@@ -132,7 +132,7 @@ impl<T: 'static> Render for PromptEditor<T> {
let error_message = SharedString::from(error.to_string());
if error.error_code() == proto::ErrorCode::RateLimitExceeded
&& cx.has_flag::<ZedPro>()
&& cx.has_flag::<ZedProFeatureFlag>()
{
el.child(
v_flex()
@@ -931,7 +931,7 @@ impl PromptEditor<BufferCodegen> {
.update(cx, |editor, _| editor.set_read_only(false));
}
CodegenStatus::Error(error) => {
if cx.has_flag::<ZedPro>()
if cx.has_flag::<ZedProFeatureFlag>()
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
&& !dismissed_rate_limit_notice()
{

View File

@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use crate::context::{AssistantContext, format_context_as_string};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
@@ -13,6 +13,8 @@ use editor::{
};
use file_icons::FileIcons;
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt as _, future};
use gpui::{
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
@@ -22,6 +24,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
use prompt_store::PromptStore;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@@ -31,7 +34,7 @@ use workspace::Workspace;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
@@ -49,9 +52,12 @@ pub struct MessageEditor {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
last_loaded_context: Option<ContextLoadResult>,
context_load_task: Option<Shared<Task<()>>>,
profile_selector: Entity<ProfileSelector>,
edits_expanded: bool,
editor_is_expanded: bool,
@@ -68,6 +74,7 @@ impl MessageEditor {
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>,
window: &mut Window,
@@ -135,13 +142,11 @@ impl MessageEditor {
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.message_or_context_changed(true, cx);
}
EditorEvent::BufferEdited => this.handle_message_changed(cx),
_ => {}
}),
cx.observe(&context_store, |this, _, cx| {
this.message_or_context_changed(false, cx);
let _ = this.start_context_load(cx);
}),
];
@@ -152,8 +157,11 @@ impl MessageEditor {
incompatible_tools_state: incompatible_tools.clone(),
workspace,
context_store,
prompt_store,
context_strip,
context_picker_menu_handle,
context_load_task: None,
last_loaded_context: None,
model_selector: cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
@@ -175,6 +183,10 @@ impl MessageEditor {
}
}
pub fn context_store(&self) -> &Entity<ContextStore> {
&self.context_store
}
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
cx.notify();
}
@@ -214,6 +226,7 @@ impl MessageEditor {
) {
self.context_picker_menu_handle.toggle(window, cx);
}
pub fn remove_all_context(
&mut self,
_: &RemoveAllContext,
@@ -270,68 +283,22 @@ impl MessageEditor {
self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount);
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let context_task = self.load_context(cx);
let window_handle = window.window_handle();
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
refresh_task.await;
wait_for_images.await;
cx.spawn(async move |_this, cx| {
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
let loaded_context = loaded_context.unwrap_or_default();
thread
.update(cx, |thread, cx| {
let context = context_store.read(cx).context().clone();
thread.insert_user_message(user_message, context, checkpoint, cx);
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
})
.log_err();
context_store
.update(cx, |context_store, cx| {
let excerpt_ids = context_store
.context()
.iter()
.filter(|ctx| {
matches!(
ctx,
AssistantContext::Selection(_) | AssistantContext::Image(_)
)
})
.map(|ctx| ctx.id())
.collect::<Vec<_>>();
for id in excerpt_ids {
context_store.remove_context(id, cx);
}
})
.log_err();
if let Some(wait_for_summaries) = context_store
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
.log_err()
{
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = true;
cx.notify();
})
.log_err();
wait_for_summaries.await;
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = false;
cx.notify();
})
.log_err();
}
// Send to model after summaries are done
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
@@ -342,6 +309,30 @@ impl MessageEditor {
.detach();
}
fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
let context_store = self.context_store.clone();
cx.spawn(async move |this, cx| {
if let Some(wait_for_summaries) = context_store
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
.ok()
{
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = true;
cx.notify();
})
.ok();
wait_for_summaries.await;
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = false;
cx.notify();
})
.ok();
}
})
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cancelled = self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
@@ -1015,6 +1006,48 @@ impl MessageEditor {
self.update_token_count_task.is_some()
}
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
self.message_or_context_changed(true, cx);
}
fn start_context_load(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
let summaries_task = self.wait_for_summaries(cx);
let load_task = cx.spawn(async move |this, cx| {
// Waits for detailed summaries before `load_context`, as it directly reads these from
// the thread. TODO: Would be cleaner to have context loading await on summarization.
summaries_task.await;
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx))
});
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {
return;
};
let result = load_task.await;
this.update(cx, |this, cx| {
this.last_loaded_context = Some(result);
this.context_load_task = None;
this.message_or_context_changed(false, cx);
})
.ok();
});
// Replace existing load task, if any, causing it to be cancelled.
let load_task = load_task.shared();
self.context_load_task = Some(load_task.clone());
load_task
}
fn load_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
let context_load_task = self.start_context_load(cx);
cx.spawn(async move |this, cx| {
context_load_task.await;
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
.ok()
.flatten()
})
}
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
@@ -1024,9 +1057,7 @@ impl MessageEditor {
return;
};
let context_store = self.context_store.clone();
let editor = self.editor.clone();
let thread = self.thread.clone();
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
@@ -1035,27 +1066,33 @@ impl MessageEditor {
.await;
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = context_store.read(cx).context().iter();
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let token_count = if let Some(task) = this.update(cx, |this, cx| {
let loaded_context = this
.last_loaded_context
.as_ref()
.map(|context_load_result| &context_load_result.loaded_context);
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
{
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
if let Some(loaded_context) = loaded_context {
loaded_context.add_to_request_message(&mut request_message);
}
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,

View File

@@ -32,7 +32,7 @@ impl TerminalCodegen {
}
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
@@ -45,6 +45,7 @@ impl TerminalCodegen {
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(async move |this, cx| {
let prompt = prompt_task.await;
let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await;

View File

@@ -1,4 +1,4 @@
use crate::context::attach_context_to_message;
use crate::context::load_context;
use crate::context_store::ContextStore;
use crate::inline_prompt_editor::{
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
use collections::{HashMap, VecDeque};
use editor::{MultiBuffer, actions::SelectAll};
use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::{PromptBuilder, PromptStore};
use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::TerminalView;
@@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
terminal_view: &Entity<TerminalView>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
prompt_editor,
workspace.clone(),
context_store,
prompt_store,
window,
cx,
);
@@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
.log_err();
let codegen = assist.codegen.clone();
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
return;
};
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
}
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
@@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
&self,
assist_id: TerminalInlineAssistId,
cx: &mut App,
) -> Result<LanguageModelRequest> {
) -> Result<Task<LanguageModelRequest>> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let shell = std::env::var("SHELL").ok();
@@ -246,28 +248,40 @@ impl TerminalInlineAssistant {
&latest_output,
)?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
let contexts = assist
.context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
let project = workspace.project();
load_context(contexts, project, &assist.prompt_store, cx)
})?;
attach_context_to_message(
&mut request_message,
assist.context_store.read(cx).context().iter(),
cx,
);
Ok(cx.background_spawn(async move {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
request_message.content.push(prompt.into());
context_load_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![request_message],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
})
request_message.content.push(prompt.into());
LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![request_message],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
}
}))
}
fn finish_assist(
@@ -380,6 +394,7 @@ struct TerminalInlineAssist {
codegen: Entity<TerminalCodegen>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@@ -390,6 +405,7 @@ impl TerminalInlineAssist {
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut App,
) -> Self {
@@ -400,6 +416,7 @@ impl TerminalInlineAssist {
codegen: codegen.clone(),
workspace: workspace.clone(),
context_store,
prompt_store,
_subscriptions: vec![
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
TerminalInlineAssistant::update_global(cx, |this, cx| {

View File

@@ -8,7 +8,7 @@ use anyhow::{Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
@@ -17,8 +17,8 @@ use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
@@ -35,7 +35,7 @@ use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
@@ -98,8 +98,7 @@ pub struct Message {
pub id: MessageId,
pub role: Role,
pub segments: Vec<MessageSegment>,
pub context: String,
pub images: Vec<LanguageModelImage>,
pub loaded_context: LoadedContext,
}
impl Message {
@@ -138,8 +137,8 @@ impl Message {
pub fn to_string(&self) -> String {
let mut result = String::new();
if !self.context.is_empty() {
result.push_str(&self.context);
if !self.loaded_context.text.is_empty() {
result.push_str(&self.loaded_context.text);
}
for segment in &self.segments {
@@ -294,8 +293,6 @@ pub struct Thread {
messages: Vec<Message>,
next_message_id: MessageId,
last_prompt_id: PromptId,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
@@ -345,8 +342,6 @@ impl Thread {
messages: Vec::new(),
next_message_id: MessageId(0),
last_prompt_id: PromptId::new(),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
project_context: system_prompt,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
@@ -391,8 +386,7 @@ impl Thread {
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
let tool_use =
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
Self {
id,
@@ -419,14 +413,15 @@ impl Thread {
}
})
.collect(),
context: message.context,
images: Vec::new(),
loaded_context: LoadedContext {
contexts: Vec::new(),
text: message.context,
images: Vec::new(),
},
})
.collect(),
next_message_id,
last_prompt_id: PromptId::new(),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
@@ -524,7 +519,12 @@ impl Thread {
}
pub fn message(&self, id: MessageId) -> Option<&Message> {
self.messages.iter().find(|message| message.id == id)
let index = self
.messages
.binary_search_by(|message| message.id.cmp(&id))
.ok()?;
self.messages.get(index)
}
pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
@@ -656,21 +656,43 @@ impl Thread {
return;
};
for deleted_message in self.messages.drain(message_ix..) {
self.context_by_message.remove(&deleted_message.id);
self.checkpoints_by_message.remove(&deleted_message.id);
}
cx.notify();
}
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
self.context_by_message
.get(&id)
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
self.messages
.iter()
.find(|message| message.id == id)
.into_iter()
.flat_map(|context| {
context
.iter()
.filter_map(|context_id| self.context.get(&context_id))
.flat_map(|message| message.loaded_context.contexts.iter())
}
pub fn is_turn_end(&self, ix: usize) -> bool {
if self.messages.is_empty() {
return false;
}
if !self.is_generating() && ix == self.messages.len() - 1 {
return true;
}
let Some(message) = self.messages.get(ix) else {
return false;
};
if message.role != Role::Assistant {
return false;
}
self.messages
.get(ix + 1)
.and_then(|message| {
self.message(message.id)
.map(|next_message| next_message.role == Role::User)
})
.unwrap_or(false)
}
/// Returns whether all of the tool uses have finished running.
@@ -687,8 +709,11 @@ impl Thread {
self.tool_use.tool_uses_for_message(id, cx)
}
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(id)
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
self.tool_use.tool_results_for_message(assistant_message_id)
}
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
@@ -703,95 +728,27 @@ impl Thread {
self.tool_use.tool_result_card(id).cloned()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
/// Filter out contexts that have already been included in previous messages
pub fn filter_new_context<'a>(
&self,
context: impl Iterator<Item = &'a AssistantContext>,
) -> impl Iterator<Item = &'a AssistantContext> {
context.filter(|ctx| self.is_context_new(ctx))
}
fn is_context_new(&self, context: &AssistantContext) -> bool {
!self.context.contains_key(&context.id())
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
context: Vec<AssistantContext>,
loaded_context: ContextLoadResult,
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let text = text.into();
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
let new_context: Vec<_> = context
.into_iter()
.filter(|ctx| self.is_context_new(ctx))
.collect();
if !new_context.is_empty() {
if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.context = context_string;
}
}
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.images = new_context
.iter()
.filter_map(|context| {
if let AssistantContext::Image(image_context) = context {
image_context.image_task.clone().now_or_never().flatten()
} else {
None
}
})
.collect::<Vec<_>>();
}
if !loaded_context.referenced_buffers.is_empty() {
self.action_log.update(cx, |log, cx| {
// Track all buffers added as context
for ctx in &new_context {
match ctx {
AssistantContext::File(file_ctx) => {
log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
}
AssistantContext::Directory(dir_ctx) => {
for context_buffer in &dir_ctx.context_buffers {
log.track_buffer(context_buffer.buffer.clone(), cx);
}
}
AssistantContext::Symbol(symbol_ctx) => {
log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
}
AssistantContext::Selection(selection_context) => {
log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
}
AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
| AssistantContext::Rules(_)
| AssistantContext::Image(_) => {}
}
for buffer in loaded_context.referenced_buffers {
log.track_buffer(buffer, cx);
}
});
}
let context_ids = new_context
.iter()
.map(|context| context.id())
.collect::<Vec<_>>();
self.context.extend(
new_context
.into_iter()
.map(|context| (context.id(), context)),
let message_id = self.insert_message(
Role::User,
vec![MessageSegment::Text(text.into())],
loaded_context.loaded_context,
cx,
);
self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
@@ -805,10 +762,19 @@ impl Thread {
message_id
}
pub fn insert_assistant_message(
&mut self,
segments: Vec<MessageSegment>,
cx: &mut Context<Self>,
) -> MessageId {
self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
}
pub fn insert_message(
&mut self,
role: Role,
segments: Vec<MessageSegment>,
loaded_context: LoadedContext,
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
@@ -816,8 +782,7 @@ impl Thread {
id,
role,
segments,
context: String::new(),
images: Vec::new(),
loaded_context,
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
@@ -846,7 +811,6 @@ impl Thread {
return false;
};
self.messages.remove(index);
self.context_by_message.remove(&id);
self.touch_updated_at();
cx.emit(ThreadEvent::MessageDeleted(id));
true
@@ -933,7 +897,7 @@ impl Thread {
content: tool_result.content.clone(),
})
.collect(),
context: message.context.clone(),
context: message.loaded_context.text.clone(),
})
.collect(),
initial_project_snapshot,
@@ -967,26 +931,21 @@ impl Thread {
let mut request = self.to_completion_request(cx);
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
tools.extend(
self.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
}),
);
tools
};
request.tools = self
.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
// Skip tools that cannot be supported
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema,
})
})
.collect();
}
self.stream_completion(request, model, window, cx);
@@ -1051,29 +1010,9 @@ impl Thread {
cache: false,
};
self.tool_use
.attach_tool_results(message.id, &mut request_message);
if !message.context.is_empty() {
request_message
.content
.push(MessageContent::Text(message.context.to_string()));
}
if !message.images.is_empty() {
// Some providers only support image parts after an initial text part
if request_message.content.is_empty() {
request_message
.content
.push(MessageContent::Text("Images attached by user:".to_string()));
}
for image in &message.images {
request_message
.content
.push(MessageContent::Image(image.clone()))
}
}
message
.loaded_context
.add_to_request_message(&mut request_message);
for segment in &message.segments {
match segment {
@@ -1103,10 +1042,11 @@ impl Thread {
self.tool_use
.attach_tool_uses(message.id, &mut request_message);
// Skip empty messages to avoid sending them to the LLM model
// if !request_message.content.is_empty() {
request.messages.push(request_message);
// }
if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
request.messages.push(tool_results_message);
}
}
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
@@ -1136,11 +1076,6 @@ impl Thread {
cache: false,
};
// Skip tool results during summarization.
if self.tool_use.message_has_tool_results(message.id) {
continue;
}
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => request_message
@@ -1245,22 +1180,45 @@ impl Thread {
.ok();
}
let mut request_assistant_message_id = None;
while let Some(event) = events.next().await {
if let Some((_, response_events)) = request_callback_parameters.as_mut() {
response_events
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
let event = event?;
thread.update(cx, |thread, cx| {
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(String::new())],
let event = match event {
Ok(event) => event,
Err(LanguageModelCompletionError::BadInputJson {
id,
tool_name,
raw_input: invalid_input_json,
json_parse_error,
}) => {
thread.receive_invalid_tool_json(
id,
tool_name,
invalid_input_json,
json_parse_error,
window,
cx,
);
return Ok(());
}
Err(LanguageModelCompletionError::Other(error)) => {
return Err(error);
}
};
match event {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Text(String::new())],
cx,
));
}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@@ -1275,7 +1233,9 @@ impl Thread {
LanguageModelCompletionEvent::Text(chunk) => {
cx.emit(ThreadEvent::ReceivedTextChunk);
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
if last_message.role == Role::Assistant
&& !thread.tool_use.has_tool_results(last_message.id)
{
last_message.push_text(&chunk);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
@@ -1287,11 +1247,11 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text(chunk.to_string())],
cx,
);
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Text(chunk.to_string())],
cx,
));
};
}
}
@@ -1300,7 +1260,9 @@ impl Thread {
signature,
} => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
if last_message.role == Role::Assistant
&& !thread.tool_use.has_tool_results(last_message.id)
{
last_message.push_thinking(&chunk, signature);
cx.emit(ThreadEvent::StreamedAssistantThinking(
last_message.id,
@@ -1312,25 +1274,25 @@ impl Thread {
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Thinking {
text: chunk.to_string(),
signature,
}],
cx,
);
request_assistant_message_id =
Some(thread.insert_assistant_message(
vec![MessageSegment::Thinking {
text: chunk.to_string(),
signature,
}],
cx,
));
};
}
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
let last_assistant_message_id = thread
.messages
.iter_mut()
.rfind(|message| message.role == Role::Assistant)
.map(|message| message.id)
let last_assistant_message_id = request_assistant_message_id
.unwrap_or_else(|| {
thread.insert_message(Role::Assistant, vec![], cx)
let new_assistant_message_id =
thread.insert_assistant_message(vec![], cx);
request_assistant_message_id =
Some(new_assistant_message_id);
new_assistant_message_id
});
let tool_use_id = tool_use.id.clone();
@@ -1362,7 +1324,8 @@ impl Thread {
cx.notify();
thread.auto_capture_telemetry(cx);
})?;
Ok(())
})??;
smol::future::yield_now().await;
}
@@ -1653,6 +1616,41 @@ impl Thread {
pending_tool_uses
}
pub fn receive_invalid_tool_json(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
invalid_json: Arc<str>,
error: String,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
log::error!("The model returned invalid input JSON: {invalid_json}");
let pending_tool_use = self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
cx,
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
} else {
log::error!(
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
);
format!("Unknown tool {}", tool_use_id).into()
};
cx.emit(ThreadEvent::InvalidToolInput {
tool_use_id: tool_use_id.clone(),
ui_text,
invalid_input_json: invalid_json,
});
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}
pub fn run_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
@@ -1726,32 +1724,21 @@ impl Thread {
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
dbg!("All tools finished");
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
dbg!("Attach tool results");
self.attach_tool_results(cx);
if !canceled {
self.send_to_model(model, window, cx);
}
self.auto_capture_telemetry(cx);
}
}
println!("ToolFinished {}", tool_use_id);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
});
}
/// Insert an empty message to be populated with tool results upon send.
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
// Tool results are assumed to be waiting on the next message id, so they will populate
// this empty message before sending to model. Would prefer this to be more straightforward.
self.insert_message(Role::User, vec![], cx);
self.auto_capture_telemetry(cx);
}
/// Cancels the last pending completion, if there are any pending.
///
/// Returns whether a completion was canceled.
@@ -1760,22 +1747,19 @@ impl Thread {
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
let canceled = if self.pending_completions.pop().is_some() {
true
} else {
let mut canceled = false;
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
self.tool_finished(
pending_tool_use.id.clone(),
Some(pending_tool_use),
true,
window,
cx,
);
}
canceled
};
let mut canceled = self.pending_completions.pop().is_some();
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
self.tool_finished(
pending_tool_use.id.clone(),
Some(pending_tool_use),
true,
window,
cx,
);
}
self.finalize_pending_checkpoint(cx);
canceled
}
@@ -2026,8 +2010,16 @@ impl Thread {
}
)?;
if !message.context.is_empty() {
writeln!(markdown, "{}", message.context)?;
if !message.loaded_context.text.is_empty() {
writeln!(markdown, "{}", message.loaded_context.text)?;
}
if !message.loaded_context.images.is_empty() {
writeln!(
markdown,
"\n{} images attached as context.\n",
message.loaded_context.images.len()
)?;
}
for segment in &message.segments {
@@ -2056,7 +2048,7 @@ impl Thread {
}
for tool_result in self.tool_results_for_message(message.id) {
write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
if tool_result.is_error {
write!(markdown, " (Error)")?;
}
@@ -2105,7 +2097,7 @@ impl Thread {
}
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
return;
}
@@ -2268,6 +2260,11 @@ pub enum ThreadEvent {
ui_text: Arc<str>,
input: serde_json::Value,
},
InvalidToolInput {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
invalid_input_json: Arc<str>,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
@@ -2297,7 +2294,7 @@ struct PendingCompletion {
#[cfg(test)]
mod tests {
use super::*;
use crate::{ThreadStore, context_store::ContextStore, thread_store};
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use context_server::ContextServerSettings;
use editor::EditorSettings;
@@ -2328,12 +2325,14 @@ mod tests {
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
let loaded_context = cx
.update(|cx| load_context(vec![context], &project, &None, cx))
.await;
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Please explain this code", vec![context], None, cx)
thread.insert_user_message("Please explain this code", loaded_context, None, cx)
});
// Check content and context in message object
@@ -2367,7 +2366,7 @@ fn main() {{
message.segments[0],
MessageSegment::Text("Please explain this code".to_string())
);
assert_eq!(message.context, expected_context);
assert_eq!(message.loaded_context.text, expected_context);
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2394,48 +2393,50 @@ fn main() {{
let (_, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
// Open files individually
// First message with context 1
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
.await
.unwrap();
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
.await
.unwrap();
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
.await
.unwrap();
// Get the context objects
let contexts = context_store.update(cx, |store, _| store.context().clone());
assert_eq!(contexts.len(), 3);
// First message with context 1
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message1_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
thread.insert_user_message("Message 1", loaded_context, None, cx)
});
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Message 2",
vec![contexts[0].clone(), contexts[1].clone()],
None,
cx,
)
thread.insert_user_message("Message 2", loaded_context, None, cx)
});
// Third message with all three contexts (contexts 1 and 2 should be skipped)
//
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await;
let message3_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Message 3",
vec![
contexts[0].clone(),
contexts[1].clone(),
contexts[2].clone(),
],
None,
cx,
)
thread.insert_user_message("Message 3", loaded_context, None, cx)
});
// Check what contexts are included in each message
@@ -2448,16 +2449,16 @@ fn main() {{
});
// First message should include context 1
assert!(message1.context.contains("file1.rs"));
assert!(message1.loaded_context.text.contains("file1.rs"));
// Second message should include only context 2 (not 1)
assert!(!message2.context.contains("file1.rs"));
assert!(message2.context.contains("file2.rs"));
assert!(!message2.loaded_context.text.contains("file1.rs"));
assert!(message2.loaded_context.text.contains("file2.rs"));
// Third message should include only context 3 (not 1 or 2)
assert!(!message3.context.contains("file1.rs"));
assert!(!message3.context.contains("file2.rs"));
assert!(message3.context.contains("file3.rs"));
assert!(!message3.loaded_context.text.contains("file1.rs"));
assert!(!message3.loaded_context.text.contains("file2.rs"));
assert!(message3.loaded_context.text.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2494,7 +2495,12 @@ fn main() {{
// Insert user message without any context (empty context vector)
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
thread.insert_user_message(
"What is the best way to learn Rust?",
ContextLoadResult::default(),
None,
cx,
)
});
// Check content and context in message object
@@ -2507,7 +2513,7 @@ fn main() {{
message.segments[0],
MessageSegment::Text("What is the best way to learn Rust?".to_string())
);
assert_eq!(message.context, "");
assert_eq!(message.loaded_context.text, "");
// Check message in request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2520,12 +2526,17 @@ fn main() {{
// Add second message, also without context
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Are there any good books?", vec![], None, cx)
thread.insert_user_message(
"Are there any good books?",
ContextLoadResult::default(),
None,
cx,
)
});
let message2 =
thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
assert_eq!(message2.context, "");
assert_eq!(message2.loaded_context.text, "");
// Check that both messages appear in the request
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
@@ -2559,12 +2570,14 @@ fn main() {{
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
let loaded_context = cx
.update(|cx| load_context(vec![context], &project, &None, cx))
.await;
// Insert user message with the buffer as context
thread.update(cx, |thread, cx| {
thread.insert_user_message("Explain this code", vec![context], None, cx)
thread.insert_user_message("Explain this code", loaded_context, None, cx)
});
// Create a request and check that it doesn't have a stale buffer warning yet
@@ -2592,7 +2605,12 @@ fn main() {{
// Insert another user message without context
thread.update(cx, |thread, cx| {
thread.insert_user_message("What does the code do now?", vec![], None, cx)
thread.insert_user_message(
"What does the code do now?",
ContextLoadResult::default(),
None,
cx,
)
});
// Create a new request and check for the stale buffer warning
@@ -2659,6 +2677,7 @@ fn main() {{
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
@@ -2683,15 +2702,15 @@ fn main() {{
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
.update(cx, |project, cx| {
project.open_buffer(buffer_path.clone(), cx)
})
.await
.unwrap();
context_store
.update(cx, |store, cx| {
store.add_file_from_buffer(buffer.clone(), cx)
})
.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
});
Ok(buffer)
}

View File

@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
@@ -82,12 +82,11 @@ impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
let prompt_store = prompt_store.await.ok();
let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
@@ -349,25 +348,8 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
self.prompt_store.clone()
}
pub fn load_rules(
&self,
prompt_id: UserPromptId,
cx: &App,
) -> Task<Result<(PromptMetadata, String)>> {
let prompt_id = PromptId::User { uuid: prompt_id };
let Some(prompt_store) = self.prompt_store.as_ref() else {
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return Task::ready(Err(anyhow!("User rules not found in library.")));
};
let text_task = prompt_store.load(prompt_id, cx);
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
@@ -379,16 +361,12 @@ impl ThreadStore {
self.threads.len()
}
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
threads
}
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
self.threads().into_iter().take(limit).collect()
}
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| {
Thread::new(
@@ -639,12 +617,17 @@ pub struct SerializedThread {
}
impl SerializedThread {
pub const VERSION: &'static str = "0.1.0";
pub const VERSION: &'static str = "0.2.0";
pub fn from_json(json: &[u8]) -> Result<Self> {
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
match saved_thread_json.get("version") {
Some(serde_json::Value::String(version)) => match version.as_str() {
SerializedThreadV0_1_0::VERSION => {
let saved_thread =
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
Ok(saved_thread.upgrade())
}
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
saved_thread_json,
)?),
@@ -666,6 +649,38 @@ impl SerializedThread {
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
// The structure did not change, so we are reusing the latest SerializedThread.
// When making the next version, make sure this points to SerializedThreadV0_2_0
SerializedThread,
);
impl SerializedThreadV0_1_0 {
pub const VERSION: &'static str = "0.1.0";
pub fn upgrade(self) -> SerializedThread {
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
for message in self.0.messages {
if message.role == Role::User && !message.tool_results.is_empty() {
if let Some(last_message) = messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
last_message.tool_results = message.tool_results;
continue;
}
}
messages.push(message);
}
SerializedThread { messages, ..self.0 }
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
pub id: MessageId,

View File

@@ -30,7 +30,6 @@ pub struct ToolUse {
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
@@ -42,7 +41,6 @@ impl ToolUseState {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
@@ -56,7 +54,6 @@ impl ToolUseState {
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
@@ -68,7 +65,6 @@ impl ToolUseState {
let tool_uses = message
.tool_uses
.iter()
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
@@ -86,14 +82,6 @@ impl ToolUseState {
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
}
}
Role::User => {
if !message.tool_results.is_empty() {
let tool_uses_by_user_message = this
.tool_uses_by_user_message
.entry(message.id)
.or_default();
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
@@ -102,11 +90,6 @@ impl ToolUseState {
continue;
};
if !(filter_by_tool_name)(tool_use.as_ref()) {
continue;
}
tool_uses_by_user_message.push(tool_use_id.clone());
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
@@ -119,7 +102,7 @@ impl ToolUseState {
}
}
}
Role::System => {}
Role::System | Role::User => {}
}
}
@@ -229,20 +212,26 @@ impl ToolUseState {
}
}
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
let empty = Vec::new();
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
let Some(tool_uses) = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)
else {
return Vec::new();
};
self.tool_uses_by_user_message
.get(&message_id)
.unwrap_or(&empty)
tool_uses
.iter()
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
.collect()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_uses_by_user_message
.get(&message_id)
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
}
@@ -294,14 +283,6 @@ impl ToolUseState {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
// The tool use is being requested by the Assistant, so we want to
// attach the tool results to the next user message.
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
self.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
@@ -450,7 +431,6 @@ impl ToolUseState {
message_id: MessageId,
request_message: &mut LanguageModelRequestMessage,
) {
dbg!(&self.tool_uses_by_assistant_message, &message_id);
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
for tool_use in tool_uses {
if self.tool_results.contains_key(&tool_use.id) {
@@ -468,32 +448,49 @@ impl ToolUseState {
}
}
pub fn attach_tool_results(
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id)
}
pub fn tool_results_message(
&self,
message_id: MessageId,
request_message: &mut LanguageModelRequestMessage,
) {
dbg!(&self.tool_uses_by_user_message, &message_id);
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
for tool_use_id in tool_uses {
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
request_message.content.push(MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
assistant_message_id: MessageId,
) -> Option<LanguageModelRequestMessage> {
let tool_uses = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)?;
if tool_uses.is_empty() {
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
for tool_use in tool_uses {
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
request_message
.content
.push(MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_use.id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
));
}
}));
}
}
Some(request_message)
}
}

View File

@@ -1,14 +1,15 @@
use std::sync::Arc;
use std::{rc::Rc, time::Duration};
use file_icons::FileIcons;
use futures::FutureExt;
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
use gpui::{ClickEvent, Task};
use language_model::LanguageModelImage;
use gpui::{
Animation, AnimationExt as _, AnyView, ClickEvent, Entity, MouseButton, pulsating_between,
};
use project::Project;
use prompt_store::PromptStore;
use text::OffsetRangeExt;
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
use crate::context::{AgentContext, ContextKind, ImageStatus};
#[derive(IntoElement)]
pub enum ContextPill {
@@ -73,9 +74,7 @@ impl ContextPill {
pub fn id(&self) -> ElementId {
match self {
Self::Added { context, .. } => {
ElementId::NamedInteger("context-pill".into(), context.id.0)
}
Self::Added { context, .. } => context.context.element_id("context-pill".into()),
Self::Suggested { .. } => "suggested-context-pill".into(),
}
}
@@ -170,14 +169,9 @@ impl RenderOnce for ContextPill {
.when_some(
context.render_preview.as_ref(),
|element, render_preview| {
element.hoverable_tooltip({
let render_preview = render_preview.clone();
move |_, cx| {
cx.new(|_| ContextPillPreview {
render_preview: render_preview.clone(),
})
.into()
}
let render_preview = render_preview.clone();
element.hoverable_tooltip(move |window, cx| {
render_preview(window, cx)
})
},
)
@@ -199,14 +193,17 @@ impl RenderOnce for ContextPill {
)
.when_some(on_remove.as_ref(), |element, on_remove| {
element.child(
IconButton::new(("remove", context.id.0), IconName::Close)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
IconButton::new(
context.context.element_id("remove".into()),
IconName::Close,
)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
)
})
.when_some(on_click.as_ref(), |element, on_click| {
@@ -262,23 +259,34 @@ pub enum ContextStatus {
Error { message: SharedString },
}
#[derive(RegisterComponent)]
// TODO: Component commented out due to new dependency on `Project`.
//
// #[derive(RegisterComponent)]
pub struct AddedContext {
pub id: ContextId,
pub context: AgentContext,
pub kind: ContextKind,
pub name: SharedString,
pub parent: Option<SharedString>,
pub tooltip: Option<SharedString>,
pub icon_path: Option<SharedString>,
pub status: ContextStatus,
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement + 'static>>,
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView + 'static>>,
}
impl AddedContext {
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
///
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
pub fn new(
context: AgentContext,
prompt_store: Option<&Entity<PromptStore>>,
project: &Project,
cx: &App,
) -> Option<AddedContext> {
match context {
AssistantContext::File(file_context) => {
let full_path = file_context.context_buffer.full_path(cx);
AgentContext::File(ref file_context) => {
let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -289,8 +297,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: file_context.id,
Some(AddedContext {
kind: ContextKind::File,
name,
parent,
@@ -298,18 +305,16 @@ impl AddedContext {
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
}
context,
})
}
AssistantContext::Directory(directory_context) => {
let worktree = directory_context.worktree.read(cx);
// If the directory no longer exists, use its last known path.
let full_path = worktree
.entry_for_id(directory_context.entry_id)
.map_or_else(
|| directory_context.last_path.clone(),
|entry| worktree.full_path(&entry.path).into(),
);
AgentContext::Directory(ref directory_context) => {
let worktree = project
.worktree_for_entry(directory_context.entry_id, cx)?
.read(cx);
let entry = worktree.entry_for_id(directory_context.entry_id)?;
let full_path = worktree.full_path(&entry.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -320,8 +325,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: directory_context.id,
Some(AddedContext {
kind: ContextKind::Directory,
name,
parent,
@@ -329,33 +333,34 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
}
context,
})
}
AssistantContext::Symbol(symbol_context) => AddedContext {
id: symbol_context.id,
AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
kind: ContextKind::Symbol,
name: symbol_context.context_symbol.id.name.clone(),
name: symbol_context.symbol.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
context,
}),
AssistantContext::Selection(selection_context) => {
let full_path = selection_context.context_buffer.full_path(cx);
AgentContext::Selection(ref selection_context) => {
let buffer = selection_context.buffer.read(cx);
let full_path = buffer.file()?.full_path(cx);
let mut full_path_string = full_path.to_string_lossy().into_owned();
let mut name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| full_path_string.clone());
let line_range_text = format!(
" ({}-{})",
selection_context.line_range.start.row + 1,
selection_context.line_range.end.row + 1
);
let line_range = selection_context.range.to_point(&buffer.snapshot());
let line_range_text =
format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
full_path_string.push_str(&line_range_text);
name.push_str(&line_range_text);
@@ -365,8 +370,7 @@ impl AddedContext {
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: selection_context.id,
Some(AddedContext {
kind: ContextKind::Selection,
name: name.into(),
parent,
@@ -374,22 +378,31 @@ impl AddedContext {
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: Some(Rc::new({
let content = selection_context.context_buffer.text.clone();
let buffer = selection_context.buffer.clone();
let range = selection_context.range.clone();
move |_, cx| {
div()
.id("context-pill-selection-preview")
.overflow_scroll()
.max_w_128()
.max_h_96()
.child(Label::new(content.clone()).buffer_font(cx))
.into_any_element()
let text: SharedString = buffer
.read(cx)
.text_for_range(range.clone())
.collect::<String>()
.into();
ContextPillPreview::new(cx, move |_, cx| {
div()
.id("context-pill-selection-preview")
.overflow_scroll()
.max_w_128()
.max_h_96()
.child(Label::new(text.clone()).buffer_font(cx))
.into_any_element()
})
.into()
}
})),
}
context,
})
}
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
id: fetched_url_context.id,
AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
kind: ContextKind::FetchedUrl,
name: fetched_url_context.url.clone(),
parent: None,
@@ -397,12 +410,12 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
context,
}),
AssistantContext::Thread(thread_context) => AddedContext {
id: thread_context.id,
AgentContext::Thread(ref thread_context) => Some(AddedContext {
kind: ContextKind::Thread,
name: thread_context.summary(cx),
name: thread_context.name(cx),
parent: None,
tooltip: None,
icon_path: None,
@@ -418,53 +431,74 @@ impl AddedContext {
ContextStatus::Ready
},
render_preview: None,
},
context,
}),
AssistantContext::Rules(user_rules_context) => AddedContext {
id: user_rules_context.id,
kind: ContextKind::Rules,
name: user_rules_context.title.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
AgentContext::Rules(ref user_rules_context) => {
let name = prompt_store
.as_ref()?
.read(cx)
.metadata(user_rules_context.prompt_id.into())?
.title?;
Some(AddedContext {
kind: ContextKind::Rules,
name: name.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
context,
})
}
AssistantContext::Image(image_context) => AddedContext {
id: image_context.id,
AgentContext::Image(ref image_context) => Some(AddedContext {
kind: ContextKind::Image,
name: "Image".into(),
parent: None,
tooltip: None,
icon_path: None,
status: if image_context.is_loading() {
ContextStatus::Loading {
status: match image_context.status() {
ImageStatus::Loading => ContextStatus::Loading {
message: "Loading…".into(),
}
} else if image_context.is_error() {
ContextStatus::Error {
},
ImageStatus::Error => ContextStatus::Error {
message: "Failed to load image".into(),
}
} else {
ContextStatus::Ready
},
ImageStatus::Ready => ContextStatus::Ready,
},
render_preview: Some(Rc::new({
let image = image_context.original_image.clone();
move |_, _| {
gpui::img(image.clone())
.max_w_96()
.max_h_96()
.into_any_element()
move |_, cx| {
let image = image.clone();
ContextPillPreview::new(cx, move |_, _| {
gpui::img(image.clone())
.max_w_96()
.max_h_96()
.into_any_element()
})
.into()
}
})),
},
context,
}),
}
}
}
struct ContextPillPreview {
render_preview: Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>,
render_preview: Box<dyn Fn(&mut Window, &mut App) -> AnyElement>,
}
impl ContextPillPreview {
fn new(
cx: &mut App,
render_preview: impl Fn(&mut Window, &mut App) -> AnyElement + 'static,
) -> Entity<Self> {
cx.new(|_| Self {
render_preview: Box::new(render_preview),
})
}
}
impl Render for ContextPillPreview {
@@ -478,6 +512,8 @@ impl Render for ContextPillPreview {
}
}
// TODO: Component commented out due to new dependency on `Project`.
/*
impl Component for AddedContext {
fn scope() -> ComponentScope {
ComponentScope::Agent
@@ -487,12 +523,13 @@ impl Component for AddedContext {
"AddedContext"
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
let next_context_id = ContextId::zero();
let image_ready = (
"Ready",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(0),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
}),
@@ -503,8 +540,8 @@ impl Component for AddedContext {
let image_loading = (
"Loading",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(1),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: cx
.background_spawn(async move {
@@ -520,8 +557,8 @@ impl Component for AddedContext {
let image_error = (
"Error",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(2),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(None).shared(),
}),
@@ -544,5 +581,8 @@ impl Component for AddedContext {
)
.into_any(),
)
None
}
}
*/

View File

@@ -1,5 +1,3 @@
mod supported_countries;
use std::str::FromStr;
use anyhow::{Context as _, Result, anyhow};
@@ -11,8 +9,6 @@ use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
pub use supported_countries::*;
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]

View File

@@ -1,225 +0,0 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by Anthropic.
///
/// <https://www.anthropic.com/supported-countries>
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by Anthropic.
///
/// https://www.anthropic.com/supported-countries
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"AL", // Albania
"DZ", // Algeria
"AS", // American Samoa (US)
"AD", // Andorra
"AO", // Angola
"AI", // Anguilla (UK)
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BM", // Bermuda (UK)
"BT", // Bhutan
"BO", // Bolivia
"BA", // Bosnia and Herzegovina
"BW", // Botswana
"BR", // Brazil
"IO", // British Indian Ocean Territory (UK)
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"KY", // Cayman Islands (UK)
"TD", // Chad
"CL", // Chile
"CX", // Christmas Island (AU)
"CC", // Cocos (Keeling) Islands (AU)
"CO", // Colombia
"KM", // Comoros
"CG", // Congo (Brazzaville)
"CK", // Cook Islands (NZ)
"CR", // Costa Rica
"CI", // Côte d'Ivoire
"HR", // Croatia
"CY", // Cyprus
"CZ", // Czechia (Czech Republic)
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"EE", // Estonia
"SZ", // Eswatini
"FK", // Falkland Islands (UK)
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GF", // French Guiana (FR)
"PF", // French Polynesia (FR)
"TF", // French Southern Territories
"GA", // Gabon
"GM", // Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GI", // Gibraltar (UK)
"GR", // Greece
"GD", // Grenada
"GT", // Guatemala
"GU", // Guam (US)
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"HM", // Heard Island and McDonald Islands (AU)
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KW", // Kuwait
"KG", // Kyrgyzstan
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MD", // Moldova
"MC", // Monaco
"MN", // Mongolia
"MS", // Montserrat (UK)
"ME", // Montenegro
"MA", // Morocco
"MZ", // Mozambique
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NZ", // New Zealand
"NE", // Niger
"NG", // Nigeria
"NF", // Norfolk Island (AU)
"MK", // North Macedonia
"MI", // Northern Mariana Islands (UK)
"NO", // Norway
"NU", // Niue (NZ)
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PN", // Pitcairn (UK)
"PL", // Poland
"PT", // Portugal
"PR", // Puerto Rico (US)
"QA", // Qatar
"RO", // Romania
"RW", // Rwanda
"BL", // Saint Barthélemy (FR)
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"MF", // Saint Martin (FR)
"PM", // Saint Pierre and Miquelon (FR)
"VC", // Saint Vincent and the Grenadines
"WS", // Samoa
"SM", // San Marino
"ST", // São Tomé and Príncipe
"SA", // Saudi Arabia
"SN", // Senegal
"RS", // Serbia
"SC", // Seychelles
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"ZA", // South Africa
"KR", // South Korea
"ES", // Spain
"LK", // Sri Lanka
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste
"TG", // Togo
"TK", // Tokelau (NZ)
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Türkiye (Turkey)
"TM", // Turkmenistan
"TC", // Turks and Caicos Islands (UK)
"TV", // Tuvalu
"UG", // Uganda
"UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
"AE", // United Arab Emirates
"GB", // United Kingdom
"UM", // United States Minor Outlying Islands (US)
"US", // United States of America
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VA", // Vatican City
"VN", // Vietnam
"VI", // Virgin Islands (US)
"VG", // Virgin Islands (UK)
"WF", // Wallis and Futuna (FR)
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});

View File

@@ -49,7 +49,7 @@ menu.workspace = true
multi_buffer.workspace = true
parking_lot.workspace = true
project.workspace = true
prompt_library.workspace = true
rules_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
rope.workspace = true

View File

@@ -101,7 +101,7 @@ pub fn init(
SlashCommandSettings::register(cx);
assistant_context_editor::init(client.clone(), cx);
prompt_library::init(cx);
rules_library::init(cx);
init_language_model_settings(cx);
assistant_slash_command::init(cx);
assistant_tool::init(cx);

View File

@@ -25,8 +25,8 @@ use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, UserPromptId};
use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search::DivRegistrar};
use settings::{Settings, update_settings_file};
@@ -43,7 +43,7 @@ use workspace::{
dock::{DockPosition, Panel, PanelEvent},
pane,
};
use zed_actions::assistant::{InlineAssist, OpenPromptLibrary, ShowConfiguration, ToggleFocus};
use zed_actions::assistant::{InlineAssist, OpenRulesLibrary, ShowConfiguration, ToggleFocus};
pub fn init(cx: &mut App) {
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
@@ -57,11 +57,11 @@ pub fn init(cx: &mut App) {
.register_action(AssistantPanel::show_configuration)
.register_action(AssistantPanel::create_new_context)
.register_action(AssistantPanel::restart_context_servers)
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(action, window, cx)
panel.deploy_rules_library(action, window, cx)
});
}
});
@@ -272,8 +272,8 @@ impl AssistantPanel {
.action("New Chat", Box::new(NewChat))
.action("History", Box::new(DeployHistory))
.action(
"Prompt Library",
Box::new(OpenPromptLibrary::default()),
"Rules Library",
Box::new(OpenRulesLibrary::default()),
)
.action("Configure", Box::new(ShowConfiguration))
.action(zoom_label, Box::new(ToggleZoom))
@@ -1043,13 +1043,13 @@ impl AssistantPanel {
}
}
fn deploy_prompt_library(
fn deploy_rules_library(
&mut self,
action: &OpenPromptLibrary,
action: &OpenRulesLibrary,
_window: &mut Window,
cx: &mut Context<Self>,
) {
open_prompt_library(
open_rules_library(
self.languages.clone(),
Box::new(PromptLibraryInlineAssist),
Arc::new(|| {
@@ -1059,9 +1059,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action
.prompt_to_select
.map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@@ -1235,7 +1235,7 @@ impl Render for AssistantPanel {
this.show_configuration_tab(window, cx)
}))
.on_action(cx.listener(AssistantPanel::deploy_history))
.on_action(cx.listener(AssistantPanel::deploy_prompt_library))
.on_action(cx.listener(AssistantPanel::deploy_rules_library))
.child(registrar.size_full().child(self.pane.clone()))
.into_any_element()
}
@@ -1350,13 +1350,13 @@ impl Focusable for AssistantPanel {
struct PromptLibraryInlineAssist;
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
fn assist(
&self,
prompt_editor: &Entity<Editor>,
initial_prompt: Option<String>,
window: &mut Window,
cx: &mut Context<PromptLibrary>,
cx: &mut Context<RulesLibrary>,
) {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, None, None, initial_prompt, window, cx)

View File

@@ -18,11 +18,11 @@ use editor::{
},
};
use feature_flags::{
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedPro,
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedProFeatureFlag,
};
use fs::Fs;
use futures::{
SinkExt, Stream, StreamExt,
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{BoxFuture, LocalBoxFuture},
join,
@@ -1652,7 +1652,7 @@ impl Render for PromptEditor {
let error_message = SharedString::from(error.to_string());
if error.error_code() == proto::ErrorCode::RateLimitExceeded
&& cx.has_flag::<ZedPro>()
&& cx.has_flag::<ZedProFeatureFlag>()
{
el.child(
v_flex()
@@ -1966,7 +1966,7 @@ impl PromptEditor {
.update(cx, |editor, _| editor.set_read_only(false));
}
CodegenStatus::Error(error) => {
if cx.has_flag::<ZedPro>()
if cx.has_flag::<ZedProFeatureFlag>()
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
&& !dismissed_rate_limit_notice()
{
@@ -3056,7 +3056,8 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks =
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();

View File

@@ -10,6 +10,11 @@ pub fn adapt_schema_to_format(
json: &mut Value,
format: LanguageModelToolSchemaFormat,
) -> Result<()> {
if let Value::Object(obj) = json {
obj.remove("$schema");
obj.remove("title");
}
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
@@ -30,10 +35,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
}
}
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
for key in KEYS_TO_REMOVE {
obj.remove(key);
}
obj.remove("format");
if let Some(default) = obj.get("default") {
let is_null = default.is_null();

View File

@@ -59,7 +59,6 @@ use crate::thinking_tool::ThinkingTool;
pub use create_file_tool::CreateFileToolInput;
pub use edit_file_tool::EditFileToolInput;
pub use find_path_tool::FindPathToolInput;
pub use list_directory_tool::ListDirectoryToolInput;
pub use read_file_tool::ReadFileToolInput;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
@@ -111,11 +110,38 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
#[cfg(test)]
mod tests {
use super::*;
use client::Client;
use clock::FakeSystemClock;
use http_client::FakeHttpClient;
use schemars::JsonSchema;
use serde::Serialize;
use super::*;
#[test]
fn test_json_schema() {
#[derive(Serialize, JsonSchema)]
struct GetWeatherTool {
location: String,
}
let schema = schema::json_schema_for::<GetWeatherTool>(
language_model::LanguageModelToolSchemaFormat::JsonSchema,
)
.unwrap();
assert_eq!(
schema,
serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string"
}
},
"required": ["location"],
})
);
}
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {

View File

@@ -14,7 +14,6 @@ pub mod messages;
pub mod notifications;
pub mod processed_stripe_events;
pub mod projects;
pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;

View File

@@ -1,58 +0,0 @@
use super::*;
use crate::db::tables::rate_buckets;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
impl Database {
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
/// than the currently saved timestamp.
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
if buckets.is_empty() {
return Ok(());
}
self.transaction(|tx| async move {
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
rate_buckets::ActiveModel {
user_id: ActiveValue::Set(bucket.user_id),
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
token_count: ActiveValue::Set(bucket.token_count),
last_refill: ActiveValue::Set(bucket.last_refill),
}
}))
.on_conflict(
OnConflict::columns([
rate_buckets::Column::UserId,
rate_buckets::Column::RateLimitName,
])
.update_columns([
rate_buckets::Column::TokenCount,
rate_buckets::Column::LastRefill,
])
.to_owned(),
)
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// Retrieves the rate limit for the given user and rate limit name.
pub async fn get_rate_bucket(
&self,
user_id: UserId,
rate_limit_name: &str,
) -> Result<Option<rate_buckets::Model>> {
self.transaction(|tx| async move {
let rate_limit = rate_buckets::Entity::find()
.filter(rate_buckets::Column::UserId.eq(user_id))
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
.one(&*tx)
.await?;
Ok(rate_limit)
})
.await
}
}

View File

@@ -28,7 +28,6 @@ pub mod project;
pub mod project_collaborator;
pub mod project_repository;
pub mod project_repository_statuses;
pub mod rate_buckets;
pub mod room;
pub mod room_participant;
pub mod server;

View File

@@ -1,31 +0,0 @@
use crate::db::UserId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "rate_buckets")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: UserId,
#[sea_orm(primary_key, auto_increment = false)]
pub rate_limit_name: String,
pub token_count: i32,
pub last_refill: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -6,7 +6,6 @@ pub mod env;
pub mod executor;
pub mod llm;
pub mod migrations;
mod rate_limiter;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
@@ -25,7 +24,6 @@ pub use cents::*;
use db::{ChannelId, Database};
use executor::Executor;
use llm::db::LlmDatabase;
pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@@ -295,7 +293,6 @@ pub struct AppState {
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
pub config: Config,
@@ -348,7 +345,6 @@ impl AppState {
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()

View File

@@ -13,8 +13,8 @@ use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
use collab::{
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
env, executor::Executor, rpc::ResultExt,
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor, rpc::ResultExt,
};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
@@ -111,10 +111,6 @@ async fn main() -> Result<()> {
if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically(
state.rate_limiter.clone(),
state.executor.clone(),
);
let epoch = state
.db

View File

@@ -1,321 +0,0 @@
use crate::{Database, Error, Result, db::UserId, executor::Executor};
use chrono::{DateTime, Duration, Utc};
use dashmap::{DashMap, DashSet};
use rpc::ErrorCodeExt;
use sea_orm::prelude::DateTimeUtc;
use std::sync::Arc;
use util::ResultExt;
pub trait RateLimit: Send + Sync {
fn capacity(&self) -> usize;
fn refill_duration(&self) -> Duration;
fn db_name(&self) -> &'static str;
}
/// Used to enforce per-user rate limits
pub struct RateLimiter {
buckets: DashMap<(UserId, String), RateBucket>,
dirty_buckets: DashSet<(UserId, String)>,
db: Arc<Database>,
}
impl RateLimiter {
pub fn new(db: Arc<Database>) -> Self {
RateLimiter {
buckets: DashMap::new(),
dirty_buckets: DashSet::new(),
db,
}
}
/// Spawns a new task that periodically saves rate limit data to the database.
pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
executor.clone().spawn_detached(async move {
loop {
executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
rate_limiter.save().await.log_err();
}
});
}
/// Returns an error if the user has exceeded the specified `RateLimit`.
/// Attempts to read the from the database if no cached RateBucket currently exists.
pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
self.check_internal(limit, user_id, Utc::now()).await
}
async fn check_internal(
&self,
limit: &dyn RateLimit,
user_id: UserId,
now: DateTimeUtc,
) -> Result<()> {
let bucket_key = (user_id, limit.db_name().to_string());
// Attempt to fetch the bucket from the database if it hasn't been cached.
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
// but this enforces limits across restarts so long as the database is reachable.
if !self.buckets.contains_key(&bucket_key) {
if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
self.buckets.insert(bucket_key.clone(), bucket);
self.dirty_buckets.insert(bucket_key.clone());
}
}
let mut bucket = self
.buckets
.entry(bucket_key.clone())
.or_insert_with(|| RateBucket::new(limit, now));
if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key);
Ok(())
} else {
Err(rpc::proto::ErrorCode::RateLimitExceeded
.message("rate limit exceeded".into())
.anyhow())?
}
}
async fn load_bucket(
&self,
limit: &dyn RateLimit,
user_id: UserId,
) -> Result<Option<RateBucket>, Error> {
Ok(self
.db
.get_rate_bucket(user_id, limit.db_name())
.await?
.map(|saved_bucket| {
RateBucket::from_db(
limit,
saved_bucket.token_count as usize,
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
)
}))
}
pub async fn save(&self) -> Result<()> {
let mut buckets = Vec::new();
self.dirty_buckets.retain(|key| {
if let Some(bucket) = self.buckets.get(key) {
buckets.push(crate::db::rate_buckets::Model {
user_id: key.0,
rate_limit_name: key.1.clone(),
token_count: bucket.token_count as i32,
last_refill: bucket.last_refill.naive_utc(),
});
}
false
});
match self.db.save_rate_buckets(&buckets).await {
Ok(()) => Ok(()),
Err(err) => {
for bucket in buckets {
self.dirty_buckets
.insert((bucket.user_id, bucket.rate_limit_name));
}
Err(err)
}
}
}
}
#[derive(Clone, Debug)]
struct RateBucket {
capacity: usize,
token_count: usize,
refill_time_per_token: Duration,
last_refill: DateTimeUtc,
}
impl RateBucket {
fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
Self {
capacity: limit.capacity(),
token_count: limit.capacity(),
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
last_refill: now,
}
}
fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
Self {
capacity: limit.capacity(),
token_count,
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
last_refill,
}
}
fn allow(&mut self, now: DateTimeUtc) -> bool {
self.refill(now);
if self.token_count > 0 {
self.token_count -= 1;
true
} else {
false
}
}
fn refill(&mut self, now: DateTimeUtc) {
let elapsed = now - self.last_refill;
if elapsed >= self.refill_time_per_token {
let new_tokens =
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
let unused_refill_time = Duration::milliseconds(
elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
);
self.last_refill = now - unused_refill_time;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::{NewUserParams, TestDb};
use gpui::TestAppContext;
#[gpui::test]
async fn test_rate_limiter(cx: &mut TestAppContext) {
let test_db = TestDb::sqlite(cx.executor().clone());
let db = test_db.db().clone();
let user_1 = db
.create_user(
"user-1@zed.dev",
None,
false,
NewUserParams {
github_login: "user-1".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user-2@zed.dev",
None,
false,
NewUserParams {
github_login: "user-2".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
let mut now = Utc::now();
let rate_limiter = RateLimiter::new(db.clone());
let rate_limit_a = Box::new(RateLimitA);
let rate_limit_b = Box::new(RateLimitB);
// User 1 can access resource A two times before being rate-limited.
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// User 2 can access resource A and user 1 can access resource B.
rate_limiter
.check_internal(&*rate_limit_b, user_2, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_b, user_1, now)
.await
.unwrap();
// After 1.5s, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(1500);
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// After 500ms, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(500);
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
rate_limiter.save().await.unwrap();
// Rate limits are reloaded from the database, so user A is still rate-limited
// for resource A.
let rate_limiter = RateLimiter::new(db.clone());
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// After 1s, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1);
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
}
struct RateLimitA;
impl RateLimit for RateLimitA {
fn capacity(&self) -> usize {
2
}
fn refill_duration(&self) -> Duration {
Duration::seconds(2)
}
fn db_name(&self) -> &'static str {
"rate-limit-a"
}
}
struct RateLimitB;
impl RateLimit for RateLimitB {
fn capacity(&self) -> usize {
10
}
fn refill_duration(&self) -> Duration {
Duration::seconds(3)
}
fn db_name(&self) -> &'static str {
"rate-limit-b"
}
}
}

View File

@@ -1,6 +1,7 @@
mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::LlmTokenClaims;
use crate::{
AppState, Error, Result, auth,
@@ -178,15 +179,23 @@ impl Session {
Ok(db.has_active_billing_subscription(user_id).await?)
}
pub async fn current_plan(
&self,
_db: &MutexGuard<'_, DbHandle>,
) -> anyhow::Result<proto::Plan> {
if self.is_staff() {
Ok(proto::Plan::ZedPro)
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
let user_id = self.user_id();
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
Ok(proto::Plan::Free)
}
proto::Plan::Free
};
Ok(plan)
}
fn user_id(&self) -> UserId {

View File

@@ -46,7 +46,7 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
let mut first_user = None;
let mut others = vec![];
let flag_names = ["remoting", "language-models"];
let flag_names = ["language-models"];
let mut flags = Vec::new();
let existing_feature_flags = db.list_feature_flags().await?;

View File

@@ -1,5 +1,5 @@
use crate::{
AppState, Config, RateLimiter,
AppState, Config,
db::{NewUserParams, UserId, tests::TestDb},
executor::Executor,
rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion},
@@ -517,7 +517,6 @@ impl TestServer {
blob_store_client: None,
stripe_client: None,
stripe_billing: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
kinesis_client: None,
config: Config {

View File

@@ -552,7 +552,9 @@ impl TcpTransport {
let host = connection_args.host;
let port = connection_args.port;
let mut command = util::command::new_smol_command(&binary.command);
let mut command = util::command::new_std_command(&binary.command);
util::set_pre_exec_to_start_new_session(&mut command);
let mut command = smol::process::Command::from(command);
if let Some(cwd) = &binary.cwd {
command.current_dir(cwd);
@@ -636,7 +638,9 @@ pub struct StdioTransport {
impl StdioTransport {
#[allow(dead_code, reason = "This is used in non test builds of Zed")]
async fn start(binary: &DebugAdapterBinary, _: AsyncApp) -> Result<(TransportPipe, Self)> {
let mut command = util::command::new_smol_command(&binary.command);
let mut command = util::command::new_std_command(&binary.command);
util::set_pre_exec_to_start_new_session(&mut command);
let mut command = smol::process::Command::from(command);
if let Some(cwd) = &binary.cwd {
command.current_dir(cwd);

View File

@@ -1,3 +1,4 @@
use crate::persistence::DebuggerPaneItem;
use crate::{
ClearAllBreakpoints, Continue, CreateDebuggingSession, Disconnect, Pause, Restart, StepBack,
StepInto, StepOut, StepOver, Stop, ToggleIgnoreBreakpoints, persistence,
@@ -32,8 +33,10 @@ use settings::Settings;
use std::any::TypeId;
use std::path::Path;
use std::sync::Arc;
use task::{DebugTaskDefinition, DebugTaskTemplate};
use terminal_view::terminal_panel::TerminalPanel;
use task::{
DebugTaskDefinition, DebugTaskTemplate, HideStrategy, RevealStrategy, RevealTarget, TaskId,
};
use terminal_view::TerminalView;
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
use workspace::{
Workspace,
@@ -293,23 +296,21 @@ impl DebugPanel {
(session.clone(), dap_store.boot_session(session, cx))
})?;
Self::register_session(this.clone(), session.clone(), cx).await?;
match task.await {
Err(e) => {
this.update(cx, |this, cx| {
this.workspace
.update(cx, |workspace, cx| {
workspace.show_error(&e, cx);
})
.ok();
})
.ok();
if let Err(e) = task.await {
this.update(cx, |this, cx| {
this.workspace
.update(cx, |workspace, cx| {
workspace.show_error(&e, cx);
})
.ok();
})
.ok();
session
.update(cx, |session, cx| session.shutdown(cx))?
.await;
}
Ok(_) => Self::register_session(this, session, cx).await?,
session
.update(cx, |session, cx| session.shutdown(cx))?
.await;
}
anyhow::Ok(())
@@ -467,6 +468,7 @@ impl DebugPanel {
) {
match event {
dap_store::DapStoreEvent::RunInTerminal {
session_id,
title,
cwd,
command,
@@ -476,6 +478,7 @@ impl DebugPanel {
..
} => {
self.handle_run_in_terminal_request(
*session_id,
title.clone(),
cwd.clone(),
command.clone(),
@@ -499,6 +502,7 @@ impl DebugPanel {
fn handle_run_in_terminal_request(
&self,
session_id: SessionId,
title: Option<String>,
cwd: Option<Arc<Path>>,
command: Option<String>,
@@ -506,56 +510,83 @@ impl DebugPanel {
envs: HashMap<String, String>,
mut sender: mpsc::Sender<Result<u32>>,
window: &mut Window,
cx: &mut App,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let terminal_task = self.workspace.update(cx, |workspace, cx| {
let terminal_panel = workspace.panel::<TerminalPanel>(cx).ok_or_else(|| {
anyhow!("RunInTerminal DAP request failed because TerminalPanel wasn't found")
});
let terminal_panel = match terminal_panel {
Ok(panel) => panel,
Err(err) => return Task::ready(Err(err)),
};
terminal_panel.update(cx, |terminal_panel, cx| {
let terminal_task = terminal_panel.add_terminal(
TerminalKind::Debug {
command,
args,
envs,
cwd,
title,
},
task::RevealStrategy::Never,
window,
cx,
);
cx.spawn(async move |_, cx| {
let pid_task = async move {
let terminal = terminal_task.await?;
terminal.read_with(cx, |terminal, _| terminal.pty_info.pid())
};
pid_task.await
})
let Some(session) = self
.sessions
.iter()
.find(|s| s.read(cx).session_id(cx) == session_id)
else {
return Task::ready(Err(anyhow!("no session {:?} found", session_id)));
};
let running = session.read(cx).running_state();
let cwd = cwd.map(|p| p.to_path_buf());
let shell = self
.project
.read(cx)
.terminal_settings(&cwd, cx)
.shell
.clone();
let kind = if let Some(command) = command {
let title = title.clone().unwrap_or(command.clone());
TerminalKind::Task(task::SpawnInTerminal {
id: TaskId("debug".to_string()),
full_label: title.clone(),
label: title.clone(),
command: command.clone(),
args,
command_label: title.clone(),
cwd,
env: envs,
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: RevealStrategy::NoFocus,
reveal_target: RevealTarget::Dock,
hide: HideStrategy::Never,
shell,
show_summary: false,
show_command: false,
show_rerun: false,
})
} else {
TerminalKind::Shell(cwd.map(|c| c.to_path_buf()))
};
let workspace = self.workspace.clone();
let project = self.project.downgrade();
let terminal_task = self.project.update(cx, |project, cx| {
project.create_terminal(kind, window.window_handle(), cx)
});
let terminal_task = cx.spawn_in(window, async move |_, cx| {
let terminal = terminal_task.await?;
let terminal_view = cx.new_window_entity(|window, cx| {
TerminalView::new(terminal.clone(), workspace, None, project, window, cx)
})?;
running.update_in(cx, |running, window, cx| {
running.ensure_pane_item(DebuggerPaneItem::Terminal, window, cx);
running.debug_terminal.update(cx, |debug_terminal, cx| {
debug_terminal.terminal = Some(terminal_view);
cx.notify();
});
})?;
anyhow::Ok(terminal.read_with(cx, |terminal, _| terminal.pty_info.pid())?)
});
cx.background_spawn(async move {
match terminal_task {
Ok(pid_task) => match pid_task.await {
Ok(Some(pid)) => sender.send(Ok(pid.as_u32())).await?,
Ok(None) => {
match terminal_task.await {
Ok(pid_task) => match pid_task {
Some(pid) => sender.send(Ok(pid.as_u32())).await?,
None => {
sender
.send(Err(anyhow!(
"Terminal was spawned but PID was not available"
)))
.await?
}
Err(error) => sender.send(Err(anyhow!(error))).await?,
},
Err(error) => sender.send(Err(anyhow!(error))).await?,
};

View File

@@ -1,7 +1,7 @@
use dap::debugger_settings::DebuggerSettings;
use debugger_panel::{DebugPanel, ToggleFocus};
use editor::Editor;
use feature_flags::{Debugger, FeatureFlagViewExt};
use feature_flags::{DebuggerFeatureFlag, FeatureFlagViewExt};
use gpui::{App, EntityInputHandler, actions};
use new_session_modal::NewSessionModal;
use project::debugger::{self, breakpoint_store::SourceBreakpoint};
@@ -47,7 +47,7 @@ pub fn init(cx: &mut App) {
return;
};
cx.when_flag_enabled::<Debugger>(window, |workspace, _, _| {
cx.when_flag_enabled::<DebuggerFeatureFlag>(window, |workspace, _, _| {
workspace
.register_action(|workspace, _: &ToggleFocus, window, cx| {
workspace.toggle_panel_focus::<DebugPanel>(window, cx);

View File

@@ -9,7 +9,7 @@ use util::ResultExt;
use workspace::{Member, Pane, PaneAxis, Workspace};
use crate::session::running::{
self, RunningState, SubView, breakpoint_list::BreakpointList, console::Console,
self, DebugTerminal, RunningState, SubView, breakpoint_list::BreakpointList, console::Console,
loaded_source_list::LoadedSourceList, module_list::ModuleList,
stack_frame_list::StackFrameList, variable_list::VariableList,
};
@@ -22,6 +22,7 @@ pub(crate) enum DebuggerPaneItem {
Frames,
Modules,
LoadedSources,
Terminal,
}
impl DebuggerPaneItem {
@@ -33,6 +34,7 @@ impl DebuggerPaneItem {
DebuggerPaneItem::Frames,
DebuggerPaneItem::Modules,
DebuggerPaneItem::LoadedSources,
DebuggerPaneItem::Terminal,
];
VARIANTS
}
@@ -55,6 +57,7 @@ impl DebuggerPaneItem {
DebuggerPaneItem::Frames => SharedString::new_static("Frames"),
DebuggerPaneItem::Modules => SharedString::new_static("Modules"),
DebuggerPaneItem::LoadedSources => SharedString::new_static("Sources"),
DebuggerPaneItem::Terminal => SharedString::new_static("Terminal"),
}
}
}
@@ -169,6 +172,7 @@ pub(crate) fn deserialize_pane_layout(
console: &Entity<Console>,
breakpoint_list: &Entity<BreakpointList>,
loaded_sources: &Entity<LoadedSourceList>,
terminal: &Entity<DebugTerminal>,
subscriptions: &mut HashMap<EntityId, Subscription>,
window: &mut Window,
cx: &mut Context<RunningState>,
@@ -191,6 +195,7 @@ pub(crate) fn deserialize_pane_layout(
console,
breakpoint_list,
loaded_sources,
terminal,
subscriptions,
window,
cx,
@@ -273,6 +278,13 @@ pub(crate) fn deserialize_pane_layout(
})),
cx,
)),
DebuggerPaneItem::Terminal => Box::new(SubView::new(
pane.focus_handle(cx),
terminal.clone().into(),
DebuggerPaneItem::Terminal,
None,
cx,
)),
})
.collect();

View File

@@ -104,6 +104,12 @@ impl DebugSession {
&self.mode
}
pub(crate) fn running_state(&self) -> Entity<RunningState> {
match &self.mode {
DebugSessionState::Running(running_state) => running_state.clone(),
}
}
pub(crate) fn label(&self, cx: &App) -> String {
if let Some(label) = self.label.get() {
return label.to_owned();

View File

@@ -27,6 +27,7 @@ use project::{
use rpc::proto::ViewId;
use settings::Settings;
use stack_frame_list::StackFrameList;
use terminal_view::TerminalView;
use ui::{
ActiveTheme, AnyElement, App, Context, ContextMenu, DropdownMenu, FluentBuilder,
InteractiveElement, IntoElement, Label, LabelCommon as _, ParentElement, Render, SharedString,
@@ -35,7 +36,7 @@ use ui::{
use util::ResultExt;
use variable_list::VariableList;
use workspace::{
ActivePaneDecorator, DraggedTab, Item, Member, Pane, PaneGroup, Workspace,
ActivePaneDecorator, DraggedTab, Item, ItemHandle, Member, Pane, PaneGroup, Workspace,
item::TabContentParams, move_item, pane::Event,
};
@@ -50,6 +51,7 @@ pub struct RunningState {
_subscriptions: Vec<Subscription>,
stack_frame_list: Entity<stack_frame_list::StackFrameList>,
loaded_sources_list: Entity<LoadedSourceList>,
pub debug_terminal: Entity<DebugTerminal>,
module_list: Entity<module_list::ModuleList>,
_console: Entity<Console>,
breakpoint_list: Entity<BreakpointList>,
@@ -364,6 +366,40 @@ pub(crate) fn new_debugger_pane(
ret
}
pub struct DebugTerminal {
pub terminal: Option<Entity<TerminalView>>,
focus_handle: FocusHandle,
}
impl DebugTerminal {
fn empty(cx: &mut Context<Self>) -> Self {
Self {
terminal: None,
focus_handle: cx.focus_handle(),
}
}
}
impl gpui::Render for DebugTerminal {
fn render(&mut self, _window: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
if let Some(terminal) = self.terminal.clone() {
terminal.into_any_element()
} else {
div().track_focus(&self.focus_handle).into_any_element()
}
}
}
impl Focusable for DebugTerminal {
fn focus_handle(&self, cx: &App) -> FocusHandle {
if let Some(terminal) = self.terminal.as_ref() {
return terminal.focus_handle(cx);
} else {
self.focus_handle.clone()
}
}
}
impl RunningState {
pub fn new(
session: Entity<Session>,
@@ -380,6 +416,8 @@ impl RunningState {
StackFrameList::new(workspace.clone(), session.clone(), weak_state, window, cx)
});
let debug_terminal = cx.new(DebugTerminal::empty);
let variable_list =
cx.new(|cx| VariableList::new(session.clone(), stack_frame_list.clone(), window, cx));
@@ -452,6 +490,7 @@ impl RunningState {
&console,
&breakpoint_list,
&loaded_source_list,
&debug_terminal,
&mut pane_close_subscriptions,
window,
cx,
@@ -494,6 +533,7 @@ impl RunningState {
breakpoint_list,
loaded_sources_list: loaded_source_list,
pane_close_subscriptions,
debug_terminal,
_schedule_serialize: None,
}
}
@@ -525,6 +565,90 @@ impl RunningState {
self.panes.pane_at_pixel_position(position).is_some()
}
fn create_sub_view(
&self,
item_kind: DebuggerPaneItem,
pane: &Entity<Pane>,
cx: &mut Context<Self>,
) -> Box<dyn ItemHandle> {
match item_kind {
DebuggerPaneItem::Console => {
let weak_console = self._console.clone().downgrade();
Box::new(SubView::new(
pane.focus_handle(cx),
self._console.clone().into(),
item_kind,
Some(Box::new(move |cx| {
weak_console
.read_with(cx, |console, cx| console.show_indicator(cx))
.unwrap_or_default()
})),
cx,
))
}
DebuggerPaneItem::Variables => Box::new(SubView::new(
self.variable_list.focus_handle(cx),
self.variable_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::BreakpointList => Box::new(SubView::new(
self.breakpoint_list.focus_handle(cx),
self.breakpoint_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Frames => Box::new(SubView::new(
self.stack_frame_list.focus_handle(cx),
self.stack_frame_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Modules => Box::new(SubView::new(
self.module_list.focus_handle(cx),
self.module_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
self.loaded_sources_list.focus_handle(cx),
self.loaded_sources_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Terminal => Box::new(SubView::new(
self.debug_terminal.focus_handle(cx),
self.debug_terminal.clone().into(),
item_kind,
None,
cx,
)),
}
}
pub(crate) fn ensure_pane_item(
&mut self,
item_kind: DebuggerPaneItem,
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.pane_items_status(cx).get(&item_kind) == Some(&true) {
return;
};
let pane = self.panes.last_pane();
let sub_view = self.create_sub_view(item_kind, &pane, cx);
pane.update(cx, |pane, cx| {
pane.add_item_inner(sub_view, false, false, false, None, window, cx);
})
}
pub(crate) fn add_pane_item(
&mut self,
item_kind: DebuggerPaneItem,
@@ -538,58 +662,7 @@ impl RunningState {
);
if let Some(pane) = self.panes.pane_at_pixel_position(position) {
let sub_view = match item_kind {
DebuggerPaneItem::Console => {
let weak_console = self._console.clone().downgrade();
Box::new(SubView::new(
pane.focus_handle(cx),
self._console.clone().into(),
item_kind,
Some(Box::new(move |cx| {
weak_console
.read_with(cx, |console, cx| console.show_indicator(cx))
.unwrap_or_default()
})),
cx,
))
}
DebuggerPaneItem::Variables => Box::new(SubView::new(
self.variable_list.focus_handle(cx),
self.variable_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::BreakpointList => Box::new(SubView::new(
self.breakpoint_list.focus_handle(cx),
self.breakpoint_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Frames => Box::new(SubView::new(
self.stack_frame_list.focus_handle(cx),
self.stack_frame_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Modules => Box::new(SubView::new(
self.module_list.focus_handle(cx),
self.module_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
self.loaded_sources_list.focus_handle(cx),
self.loaded_sources_list.clone().into(),
item_kind,
None,
cx,
)),
};
let sub_view = self.create_sub_view(item_kind, pane, cx);
pane.update(cx, |pane, cx| {
pane.add_item(sub_view, false, false, None, window, cx);

View File

@@ -1,4 +1,4 @@
use crate::{tests::start_debug_session, *};
use crate::{persistence::DebuggerPaneItem, tests::start_debug_session, *};
use dap::{
ErrorResponse, Message, RunInTerminalRequestArguments, SourceBreakpoint,
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
@@ -25,7 +25,7 @@ use std::{
atomic::{AtomicBool, Ordering},
},
};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
use terminal_view::terminal_panel::TerminalPanel;
use tests::{active_debug_session_panel, init_test, init_test_workspace};
use util::path;
use workspace::{Item, dock::Panel};
@@ -385,22 +385,17 @@ async fn test_handle_successful_run_in_terminal_reverse_request(
workspace
.update(cx, |workspace, _window, cx| {
let terminal_panel = workspace.panel::<TerminalPanel>(cx).unwrap();
let panel = terminal_panel.read(cx).pane().unwrap().read(cx);
assert_eq!(1, panel.items_len());
assert!(
panel
.active_item()
.unwrap()
.downcast::<TerminalView>()
.unwrap()
let debug_panel = workspace.panel::<DebugPanel>(cx).unwrap();
let session = debug_panel.read(cx).active_session().unwrap();
let running = session.read(cx).running_state();
assert_eq!(
running
.read(cx)
.terminal()
.read(cx)
.debug_terminal()
.pane_items_status(cx)
.get(&DebuggerPaneItem::Terminal),
Some(&true)
);
assert!(running.read(cx).debug_terminal.read(cx).terminal.is_some());
})
.unwrap();

View File

@@ -1,4 +1,4 @@
use feature_flags::{Debugger, FeatureFlagAppExt as _};
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt as _};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
AnyElement, BackgroundExecutor, Entity, Focusable, FontWeight, ListSizingBehavior,
@@ -812,7 +812,7 @@ impl CodeActionContents {
actions: Option<Rc<[AvailableCodeAction]>>,
cx: &App,
) -> Self {
if !cx.has_flag::<Debugger>() {
if !cx.has_flag::<DebuggerFeatureFlag>() {
if let Some(tasks) = &mut tasks {
tasks
.templates

View File

@@ -71,25 +71,26 @@ use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layo
pub use element::{
CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition,
};
use feature_flags::{Debugger, FeatureFlagAppExt};
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt};
use futures::{
FutureExt,
future::{self, Shared, join},
};
use fuzzy::StringMatchCandidate;
use ::git::Restore;
use ::git::blame::BlameEntry;
use ::git::{Restore, blame::ParsedCommitMessage};
use code_context_menus::{
AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu,
CompletionsMenu, ContextMenuOrigin,
};
use git::blame::{GitBlame, GlobalBlameRenderer};
use gpui::{
Action, Animation, AnimationExt, AnyElement, AnyWeakEntity, App, AppContext,
AsyncWindowContext, AvailableSpace, Background, Bounds, ClickEvent, ClipboardEntry,
ClipboardItem, Context, DispatchPhase, Edges, Entity, EntityInputHandler, EventEmitter,
FocusHandle, FocusOutEvent, Focusable, FontId, FontWeight, Global, HighlightStyle, Hsla,
KeyContext, Modifiers, MouseButton, MouseDownEvent, PaintQuad, ParentElement, Pixels, Render,
Action, Animation, AnimationExt, AnyElement, App, AppContext, AsyncWindowContext,
AvailableSpace, Background, Bounds, ClickEvent, ClipboardEntry, ClipboardItem, Context,
DispatchPhase, Edges, Entity, EntityInputHandler, EventEmitter, FocusHandle, FocusOutEvent,
Focusable, FontId, FontWeight, Global, HighlightStyle, Hsla, KeyContext, Modifiers,
MouseButton, MouseDownEvent, PaintQuad, ParentElement, Pixels, Render, ScrollHandle,
SharedString, Size, Stateful, Styled, Subscription, Task, TextStyle, TextStyleRefinement,
UTF16Selection, UnderlineStyle, UniformListScrollHandle, WeakEntity, WeakFocusHandle, Window,
div, impl_actions, point, prelude::*, pulsating_between, px, relative, size,
@@ -117,6 +118,7 @@ use language::{
};
use language::{BufferRow, CharClassifier, Runnable, RunnableRange, point_to_lsp};
use linked_editing_ranges::refresh_linked_ranges;
use markdown::Markdown;
use mouse_context_menu::MouseContextMenu;
use persistence::DB;
use project::{
@@ -798,6 +800,21 @@ impl ChangeList {
}
}
#[derive(Clone)]
struct InlineBlamePopoverState {
scroll_handle: ScrollHandle,
commit_message: Option<ParsedCommitMessage>,
markdown: Entity<Markdown>,
}
struct InlineBlamePopover {
position: gpui::Point<Pixels>,
show_task: Option<Task<()>>,
hide_task: Option<Task<()>>,
popover_bounds: Option<Bounds<Pixels>>,
popover_state: InlineBlamePopoverState,
}
/// Zed's primary implementation of text input, allowing users to edit a [`MultiBuffer`].
///
/// See the [module level documentation](self) for more information.
@@ -866,6 +883,7 @@ pub struct Editor {
context_menu_options: Option<ContextMenuOptions>,
mouse_context_menu: Option<MouseContextMenu>,
completion_tasks: Vec<(CompletionId, Task<Option<()>>)>,
inline_blame_popover: Option<InlineBlamePopover>,
signature_help_state: SignatureHelpState,
auto_signature_help: Option<bool>,
find_all_references_task_sources: Vec<Anchor>,
@@ -922,7 +940,6 @@ pub struct Editor {
show_git_blame_gutter: bool,
show_git_blame_inline: bool,
show_git_blame_inline_delay_task: Option<Task<()>>,
pub git_blame_inline_tooltip: Option<AnyWeakEntity>,
git_blame_inline_enabled: bool,
render_diff_hunk_controls: RenderDiffHunkControlsFn,
serialize_dirty_buffers: bool,
@@ -1665,6 +1682,7 @@ impl Editor {
context_menu_options: None,
mouse_context_menu: None,
completion_tasks: Default::default(),
inline_blame_popover: Default::default(),
signature_help_state: SignatureHelpState::default(),
auto_signature_help: None,
find_all_references_task_sources: Vec::new(),
@@ -1730,7 +1748,6 @@ impl Editor {
show_git_blame_inline: false,
show_selection_menu: None,
show_git_blame_inline_delay_task: None,
git_blame_inline_tooltip: None,
git_blame_inline_enabled: ProjectSettings::get_global(cx).git.inline_blame_enabled(),
render_diff_hunk_controls: Arc::new(render_diff_hunk_controls),
serialize_dirty_buffers: ProjectSettings::get_global(cx)
@@ -1806,6 +1823,7 @@ impl Editor {
);
});
editor.hide_signature_help(cx, SignatureHelpHiddenBy::Escape);
editor.inline_blame_popover.take();
}
}
EditorEvent::Edited { .. } => {
@@ -2603,6 +2621,7 @@ impl Editor {
self.update_visible_inline_completion(window, cx);
self.edit_prediction_requires_modifier_in_indent_conflict = true;
linked_editing_ranges::refresh_linked_ranges(self, window, cx);
self.inline_blame_popover.take();
if self.git_blame_inline_enabled {
self.start_inline_blame_timer(window, cx);
}
@@ -5140,7 +5159,7 @@ impl Editor {
Self::build_tasks_context(&project, &buffer, buffer_row, tasks, cx)
});
let debugger_flag = cx.has_flag::<Debugger>();
let debugger_flag = cx.has_flag::<DebuggerFeatureFlag>();
Some(cx.spawn_in(window, async move |editor, cx| {
let task_context = match task_context {
@@ -5483,6 +5502,82 @@ impl Editor {
}
}
fn show_blame_popover(
&mut self,
blame_entry: &BlameEntry,
position: gpui::Point<Pixels>,
cx: &mut Context<Self>,
) {
if let Some(state) = &mut self.inline_blame_popover {
state.hide_task.take();
cx.notify();
} else {
let delay = EditorSettings::get_global(cx).hover_popover_delay;
let show_task = cx.spawn(async move |editor, cx| {
cx.background_executor()
.timer(std::time::Duration::from_millis(delay))
.await;
editor
.update(cx, |editor, cx| {
if let Some(state) = &mut editor.inline_blame_popover {
state.show_task = None;
cx.notify();
}
})
.ok();
});
let Some(blame) = self.blame.as_ref() else {
return;
};
let blame = blame.read(cx);
let details = blame.details_for_entry(&blame_entry);
let markdown = cx.new(|cx| {
Markdown::new(
details
.as_ref()
.map(|message| message.message.clone())
.unwrap_or_default(),
None,
None,
cx,
)
});
self.inline_blame_popover = Some(InlineBlamePopover {
position,
show_task: Some(show_task),
hide_task: None,
popover_bounds: None,
popover_state: InlineBlamePopoverState {
scroll_handle: ScrollHandle::new(),
commit_message: details,
markdown,
},
});
}
}
fn hide_blame_popover(&mut self, cx: &mut Context<Self>) {
if let Some(state) = &mut self.inline_blame_popover {
if state.show_task.is_some() {
self.inline_blame_popover.take();
cx.notify();
} else {
let hide_task = cx.spawn(async move |editor, cx| {
cx.background_executor()
.timer(std::time::Duration::from_millis(100))
.await;
editor
.update(cx, |editor, cx| {
editor.inline_blame_popover.take();
cx.notify();
})
.ok();
});
state.hide_task = Some(hide_task);
}
}
}
fn refresh_document_highlights(&mut self, cx: &mut Context<Self>) -> Option<()> {
if self.pending_rename.is_some() {
return None;
@@ -9055,7 +9150,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
if !cx.has_flag::<Debugger>() {
if !cx.has_flag::<DebuggerFeatureFlag>() {
return;
}
let source = self
@@ -16657,12 +16752,7 @@ impl Editor {
pub fn render_git_blame_inline(&self, window: &Window, cx: &App) -> bool {
self.show_git_blame_inline
&& (self.focus_handle.is_focused(window)
|| self
.git_blame_inline_tooltip
.as_ref()
.and_then(|t| t.upgrade())
.is_some())
&& (self.focus_handle.is_focused(window) || self.inline_blame_popover.is_some())
&& !self.newest_selection_head_on_empty_line(cx)
&& self.has_blame_entries(cx)
}

View File

@@ -30,17 +30,21 @@ use crate::{
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use client::ParticipantIndex;
use collections::{BTreeMap, HashMap};
use feature_flags::{Debugger, FeatureFlagAppExt};
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt};
use file_icons::FileIcons;
use git::{Oid, blame::BlameEntry, status::FileStatus};
use git::{
Oid,
blame::{BlameEntry, ParsedCommitMessage},
status::FileStatus,
};
use gpui::{
Action, Along, AnyElement, App, AvailableSpace, Axis as ScrollbarAxis, BorderStyle, Bounds,
ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges, Element,
ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, Hsla,
Action, Along, AnyElement, App, AppContext, AvailableSpace, Axis as ScrollbarAxis, BorderStyle,
Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges,
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, Hsla,
InteractiveElement, IntoElement, Keystroke, Length, ModifiersChangedEvent, MouseButton,
MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta,
ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled,
TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement,
Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background,
transparent_black,
};
@@ -49,6 +53,7 @@ use language::language_settings::{
IndentGuideBackgroundColoring, IndentGuideColoring, IndentGuideSettings, ShowWhitespaceSetting,
};
use lsp::DiagnosticSeverity;
use markdown::Markdown;
use multi_buffer::{
Anchor, ExcerptId, ExcerptInfo, ExpandExcerptDirection, ExpandInfo, MultiBufferPoint,
MultiBufferRow, RowInfo,
@@ -542,7 +547,7 @@ impl EditorElement {
register_action(editor, window, Editor::insert_uuid_v4);
register_action(editor, window, Editor::insert_uuid_v7);
register_action(editor, window, Editor::open_selections_in_multibuffer);
if cx.has_flag::<Debugger>() {
if cx.has_flag::<DebuggerFeatureFlag>() {
register_action(editor, window, Editor::toggle_breakpoint);
register_action(editor, window, Editor::edit_log_breakpoint);
register_action(editor, window, Editor::enable_breakpoint);
@@ -1749,6 +1754,7 @@ impl EditorElement {
content_origin: gpui::Point<Pixels>,
scroll_pixel_position: gpui::Point<Pixels>,
line_height: Pixels,
text_hitbox: &Hitbox,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
@@ -1780,21 +1786,13 @@ impl EditorElement {
padding * em_width
};
let workspace = editor.workspace()?.downgrade();
let blame_entry = blame
.update(cx, |blame, cx| {
blame.blame_for_rows(&[*row_info], cx).next()
})
.flatten()?;
let mut element = render_inline_blame_entry(
self.editor.clone(),
workspace,
&blame,
blame_entry,
&self.style,
cx,
)?;
let mut element = render_inline_blame_entry(blame_entry.clone(), &self.style, cx)?;
let start_y = content_origin.y
+ line_height * (display_row.as_f32() - scroll_pixel_position.y / line_height);
@@ -1820,11 +1818,122 @@ impl EditorElement {
};
let absolute_offset = point(start_x, start_y);
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
let bounds = Bounds::new(absolute_offset, size);
self.layout_blame_entry_popover(
bounds,
blame_entry,
blame,
line_height,
text_hitbox,
window,
cx,
);
element.prepaint_as_root(absolute_offset, AvailableSpace::min_size(), window, cx);
Some(element)
}
fn layout_blame_entry_popover(
&self,
parent_bounds: Bounds<Pixels>,
blame_entry: BlameEntry,
blame: Entity<GitBlame>,
line_height: Pixels,
text_hitbox: &Hitbox,
window: &mut Window,
cx: &mut App,
) {
let mouse_position = window.mouse_position();
let mouse_over_inline_blame = parent_bounds.contains(&mouse_position);
let mouse_over_popover = self.editor.update(cx, |editor, _| {
editor
.inline_blame_popover
.as_ref()
.and_then(|state| state.popover_bounds)
.map_or(false, |bounds| bounds.contains(&mouse_position))
});
self.editor.update(cx, |editor, cx| {
if mouse_over_inline_blame || mouse_over_popover {
editor.show_blame_popover(&blame_entry, mouse_position, cx);
} else {
editor.hide_blame_popover(cx);
}
});
let should_draw = self.editor.update(cx, |editor, _| {
editor
.inline_blame_popover
.as_ref()
.map_or(false, |state| state.show_task.is_none())
});
if should_draw {
let maybe_element = self.editor.update(cx, |editor, cx| {
editor
.workspace()
.map(|workspace| workspace.downgrade())
.zip(
editor
.inline_blame_popover
.as_ref()
.map(|p| p.popover_state.clone()),
)
.and_then(|(workspace, popover_state)| {
render_blame_entry_popover(
blame_entry,
popover_state.scroll_handle,
popover_state.commit_message,
popover_state.markdown,
workspace,
&blame,
window,
cx,
)
})
});
if let Some(mut element) = maybe_element {
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
let origin = self.editor.update(cx, |editor, _| {
let target_point = editor
.inline_blame_popover
.as_ref()
.map_or(mouse_position, |state| state.position);
let overall_height = size.height + HOVER_POPOVER_GAP;
let popover_origin = if target_point.y > overall_height {
point(target_point.x, target_point.y - size.height)
} else {
point(
target_point.x,
target_point.y + line_height + HOVER_POPOVER_GAP,
)
};
let horizontal_offset = (text_hitbox.top_right().x
- POPOVER_RIGHT_OFFSET
- (popover_origin.x + size.width))
.min(Pixels::ZERO);
point(popover_origin.x + horizontal_offset, popover_origin.y)
});
let popover_bounds = Bounds::new(origin, size);
self.editor.update(cx, |editor, _| {
if let Some(state) = &mut editor.inline_blame_popover {
state.popover_bounds = Some(popover_bounds);
}
});
window.defer_draw(element, origin, 2);
}
}
}
fn layout_blame_entries(
&self,
buffer_rows: &[RowInfo],
@@ -5851,24 +5960,35 @@ fn prepaint_gutter_button(
}
fn render_inline_blame_entry(
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
blame: &Entity<GitBlame>,
blame_entry: BlameEntry,
style: &EditorStyle,
cx: &mut App,
) -> Option<AnyElement> {
let renderer = cx.global::<GlobalBlameRenderer>().0.clone();
renderer.render_inline_blame_entry(&style.text, blame_entry, cx)
}
fn render_blame_entry_popover(
blame_entry: BlameEntry,
scroll_handle: ScrollHandle,
commit_message: Option<ParsedCommitMessage>,
markdown: Entity<Markdown>,
workspace: WeakEntity<Workspace>,
blame: &Entity<GitBlame>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
let renderer = cx.global::<GlobalBlameRenderer>().0.clone();
let blame = blame.read(cx);
let details = blame.details_for_entry(&blame_entry);
let repository = blame.repository(cx)?.clone();
renderer.render_inline_blame_entry(
&style.text,
renderer.render_blame_entry_popover(
blame_entry,
details,
scroll_handle,
commit_message,
markdown,
repository,
workspace,
editor,
window,
cx,
)
}
@@ -6917,7 +7037,7 @@ impl Element for EditorElement {
let mut breakpoint_rows = self.editor.update(cx, |editor, cx| {
editor.active_breakpoints(start_row..end_row, window, cx)
});
if cx.has_flag::<Debugger>() {
if cx.has_flag::<DebuggerFeatureFlag>() {
for display_row in breakpoint_rows.keys() {
active_rows.entry(*display_row).or_default().breakpoint = true;
}
@@ -6940,7 +7060,7 @@ impl Element for EditorElement {
// We add the gutter breakpoint indicator to breakpoint_rows after painting
// line numbers so we don't paint a line number debug accent color if a user
// has their mouse over that line when a breakpoint isn't there
if cx.has_flag::<Debugger>() {
if cx.has_flag::<DebuggerFeatureFlag>() {
let gutter_breakpoint_indicator =
self.editor.read(cx).gutter_breakpoint_indicator.0;
if let Some((gutter_breakpoint_point, _)) =
@@ -7046,14 +7166,7 @@ impl Element for EditorElement {
blame.blame_for_rows(&[row_infos], cx).next()
})
.flatten()?;
let mut element = render_inline_blame_entry(
self.editor.clone(),
editor.workspace()?.downgrade(),
blame,
blame_entry,
&style,
cx,
)?;
let mut element = render_inline_blame_entry(blame_entry, &style, cx)?;
let inline_blame_padding = INLINE_BLAME_PADDING_EM_WIDTHS * em_advance;
Some(
element
@@ -7262,6 +7375,7 @@ impl Element for EditorElement {
content_origin,
scroll_pixel_position,
line_height,
&text_hitbox,
window,
cx,
);
@@ -7462,7 +7576,7 @@ impl Element for EditorElement {
let show_breakpoints = snapshot
.show_breakpoints
.unwrap_or(gutter_settings.breakpoints);
let breakpoints = if cx.has_flag::<Debugger>() && show_breakpoints {
let breakpoints = if cx.has_flag::<DebuggerFeatureFlag>() && show_breakpoints {
self.layout_breakpoints(
line_height,
start_row..end_row,

View File

@@ -7,10 +7,11 @@ use git::{
parse_git_remote_url,
};
use gpui::{
AnyElement, App, AppContext as _, Context, Entity, Hsla, Subscription, Task, TextStyle,
WeakEntity, Window,
AnyElement, App, AppContext as _, Context, Entity, Hsla, ScrollHandle, Subscription, Task,
TextStyle, WeakEntity, Window,
};
use language::{Bias, Buffer, BufferSnapshot, Edit};
use markdown::Markdown;
use multi_buffer::RowInfo;
use project::{
Project, ProjectItem,
@@ -98,10 +99,18 @@ pub trait BlameRenderer {
&self,
_: &TextStyle,
_: BlameEntry,
_: &mut App,
) -> Option<AnyElement>;
fn render_blame_entry_popover(
&self,
_: BlameEntry,
_: ScrollHandle,
_: Option<ParsedCommitMessage>,
_: Entity<Markdown>,
_: Entity<Repository>,
_: WeakEntity<Workspace>,
_: Entity<Editor>,
_: &mut Window,
_: &mut App,
) -> Option<AnyElement>;
@@ -139,10 +148,20 @@ impl BlameRenderer for () {
&self,
_: &TextStyle,
_: BlameEntry,
_: &mut App,
) -> Option<AnyElement> {
None
}
fn render_blame_entry_popover(
&self,
_: BlameEntry,
_: ScrollHandle,
_: Option<ParsedCommitMessage>,
_: Entity<Markdown>,
_: Entity<Repository>,
_: WeakEntity<Workspace>,
_: Entity<Editor>,
_: &mut Window,
_: &mut App,
) -> Option<AnyElement> {
None

View File

@@ -957,6 +957,7 @@ impl Item for Editor {
cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| {
if matches!(event, workspace::Event::ModalOpened) {
editor.mouse_context_menu.take();
editor.inline_blame_popover.take();
}
})
.detach();

View File

@@ -10,7 +10,7 @@ use crate::{
ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
};
use agent::ThreadEvent;
use agent::{ContextLoadResult, ThreadEvent};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
@@ -115,7 +115,12 @@ impl ExampleContext {
pub fn push_user_message(&mut self, text: impl ToString) {
self.app
.update_entity(&self.agent_thread, |thread, cx| {
thread.insert_user_message(text.to_string(), vec![], None, cx);
thread.insert_user_message(
text.to_string(),
ContextLoadResult::default(),
None,
cx,
);
})
.unwrap();
}
@@ -253,6 +258,9 @@ impl ExampleContext {
}
});
}
ThreadEvent::InvalidToolInput { .. } => {
println!("{log_prefix} invalid tool input");
}
ThreadEvent::ToolConfirmationNeeded => {
panic!(
"{}Bug: Tool confirmation should not be required in eval",

View File

@@ -13,13 +13,11 @@ use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
mod add_arg_to_trait_method;
mod file_search;
mod no_empty_messages;
pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
let mut threads: Vec<Rc<dyn Example>> = vec![
Rc::new(file_search::FileSearchExample),
Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
Rc::new(no_empty_messages::NoEmptyMessagesExample),
];
for example_path in list_declarative_examples(examples_dir).unwrap() {

View File

@@ -1,62 +0,0 @@
use anyhow::Result;
use assistant_tools::{ListDirectoryToolInput, ReadFileToolInput};
use async_trait::async_trait;
use crate::example::{Example, ExampleContext, ExampleMetadata};
pub struct NoEmptyMessagesExample;
#[async_trait(?Send)]
impl Example for NoEmptyMessagesExample {
fn meta(&self) -> ExampleMetadata {
ExampleMetadata {
name: "no_empty_messages".to_string(),
url: "https://github.com/zed-industries/zed.git".to_string(),
revision: "fcfeea4825c563715bcd1a1af809d88a37d12ccb".to_string(),
language_server: None,
max_assertions: Some(3),
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
cx.push_user_message(format!(
r#"
Summarize all the files in crates/assistant/src.
Do tool calls only and do NOT include a description of what you are doing.
Just give me an explanation of what you did after you finished all tool calls
"#
));
let response = cx.run_turn().await?;
dbg!(response);
// let tool_use = response.expect_tool("list_directory", cx)?;
// let input = tool_use.parse_input::<ListDirectoryToolInput>()?;
// cx.assert(
// &input.path == "zed/crates/assistant/src",
// "Path matches directory",
// )?;
for file in [
"assistant.rs",
"assistant_configuration.rs",
"assistant_panel.rs",
"inline_assistant.rs",
"slash_command_settings.rs",
"terminal_inline_assistant.rs",
] {
let response = cx.run_turn().await?;
dbg!(response);
// let tool_use = response.expect_tool("read_file", cx)?;
// let input = tool_use.parse_input::<ReadFileToolInput>()?;
// dbg!(&input);
// cx.assert(
// input.path == format!("zed/crates/assistant/src/{file}"),
// format!("Path {} matches file {file}", input.path),
// )?;
}
cx.run_to_end().await?;
Ok(())
}
}

View File

@@ -218,8 +218,14 @@ impl ExampleInstance {
});
let tools = cx.new(|_| ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let prompt_store = None;
let thread_store = ThreadStore::load(
project.clone(),
tools,
prompt_store,
app_state.prompt_builder.clone(),
cx,
);
let meta = self.thread.meta();
let this = self.clone();

View File

@@ -64,23 +64,18 @@ impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag {
const NAME: &'static str = "predict-edits-rate-completions";
}
pub struct Remoting {}
impl FeatureFlag for Remoting {
const NAME: &'static str = "remoting";
}
pub struct LanguageModels {}
impl FeatureFlag for LanguageModels {
pub struct LanguageModelsFeatureFlag {}
impl FeatureFlag for LanguageModelsFeatureFlag {
const NAME: &'static str = "language-models";
}
pub struct LlmClosedBeta {}
impl FeatureFlag for LlmClosedBeta {
pub struct LlmClosedBetaFeatureFlag {}
impl FeatureFlag for LlmClosedBetaFeatureFlag {
const NAME: &'static str = "llm-closed-beta";
}
pub struct ZedPro {}
impl FeatureFlag for ZedPro {
pub struct ZedProFeatureFlag {}
impl FeatureFlag for ZedProFeatureFlag {
const NAME: &'static str = "zed-pro";
}
@@ -90,13 +85,13 @@ impl FeatureFlag for NotebookFeatureFlag {
const NAME: &'static str = "notebooks";
}
pub struct Debugger {}
impl FeatureFlag for Debugger {
pub struct DebuggerFeatureFlag {}
impl FeatureFlag for DebuggerFeatureFlag {
const NAME: &'static str = "debugger";
}
pub struct ThreadAutoCapture {}
impl FeatureFlag for ThreadAutoCapture {
pub struct ThreadAutoCaptureFeatureFlag {}
impl FeatureFlag for ThreadAutoCaptureFeatureFlag {
const NAME: &'static str = "thread-auto-capture";
fn enabled_for_staff() -> bool {

View File

@@ -1,19 +1,23 @@
use crate::{commit_tooltip::CommitTooltip, commit_view::CommitView};
use editor::{BlameRenderer, Editor};
use crate::{
commit_tooltip::{CommitAvatar, CommitDetails, CommitTooltip},
commit_view::CommitView,
};
use editor::{BlameRenderer, Editor, hover_markdown_style};
use git::{
blame::{BlameEntry, ParsedCommitMessage},
repository::CommitSummary,
};
use gpui::{
AnyElement, App, AppContext as _, ClipboardItem, Element as _, Entity, Hsla,
InteractiveElement as _, MouseButton, Pixels, StatefulInteractiveElement as _, Styled as _,
Subscription, TextStyle, WeakEntity, Window, div,
ClipboardItem, Entity, Hsla, MouseButton, ScrollHandle, Subscription, TextStyle, WeakEntity,
prelude::*,
};
use markdown::{Markdown, MarkdownElement};
use project::{git_store::Repository, project_settings::ProjectSettings};
use settings::Settings as _;
use ui::{
ActiveTheme, Color, ContextMenu, FluentBuilder as _, Icon, IconName, ParentElement as _, h_flex,
};
use theme::ThemeSettings;
use time::OffsetDateTime;
use time_format::format_local_timestamp;
use ui::{ContextMenu, Divider, IconButtonShape, prelude::*};
use workspace::Workspace;
const GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED: usize = 20;
@@ -115,10 +119,6 @@ impl BlameRenderer for GitBlameRenderer {
&self,
style: &TextStyle,
blame_entry: BlameEntry,
details: Option<ParsedCommitMessage>,
repository: Entity<Repository>,
workspace: WeakEntity<Workspace>,
editor: Entity<Editor>,
cx: &mut App,
) -> Option<AnyElement> {
let relative_timestamp = blame_entry_relative_timestamp(&blame_entry);
@@ -144,25 +144,223 @@ impl BlameRenderer for GitBlameRenderer {
.child(Icon::new(IconName::FileGit).color(Color::Hint))
.child(text)
.gap_2()
.hoverable_tooltip(move |_window, cx| {
let tooltip = cx.new(|cx| {
CommitTooltip::blame_entry(
&blame_entry,
details.clone(),
repository.clone(),
workspace.clone(),
cx,
)
});
editor.update(cx, |editor, _| {
editor.git_blame_inline_tooltip = Some(tooltip.downgrade().into())
});
tooltip.into()
})
.into_any(),
)
}
fn render_blame_entry_popover(
&self,
blame: BlameEntry,
scroll_handle: ScrollHandle,
details: Option<ParsedCommitMessage>,
markdown: Entity<Markdown>,
repository: Entity<Repository>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Option<AnyElement> {
let commit_time = blame
.committer_time
.and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok())
.unwrap_or(OffsetDateTime::now_utc());
let commit_details = CommitDetails {
sha: blame.sha.to_string().into(),
commit_time,
author_name: blame
.author
.clone()
.unwrap_or("<no name>".to_string())
.into(),
author_email: blame.author_mail.clone().unwrap_or("".to_string()).into(),
message: details,
};
let avatar = CommitAvatar::new(&commit_details).render(window, cx);
let author = commit_details.author_name.clone();
let author_email = commit_details.author_email.clone();
let short_commit_id = commit_details
.sha
.get(0..8)
.map(|sha| sha.to_string().into())
.unwrap_or_else(|| commit_details.sha.clone());
let full_sha = commit_details.sha.to_string().clone();
let absolute_timestamp = format_local_timestamp(
commit_details.commit_time,
OffsetDateTime::now_utc(),
time_format::TimestampFormat::MediumAbsolute,
);
let markdown_style = {
let mut style = hover_markdown_style(window, cx);
if let Some(code_block) = &style.code_block.text {
style.base_text_style.refine(code_block);
}
style
};
let message = commit_details
.message
.as_ref()
.map(|_| MarkdownElement::new(markdown.clone(), markdown_style).into_any())
.unwrap_or("<no commit message>".into_any());
let pull_request = commit_details
.message
.as_ref()
.and_then(|details| details.pull_request.clone());
let ui_font_size = ThemeSettings::get_global(cx).ui_font_size(cx);
let message_max_height = window.line_height() * 12 + (ui_font_size / 0.4);
let commit_summary = CommitSummary {
sha: commit_details.sha.clone(),
subject: commit_details
.message
.as_ref()
.map_or(Default::default(), |message| {
message
.message
.split('\n')
.next()
.unwrap()
.trim_end()
.to_string()
.into()
}),
commit_timestamp: commit_details.commit_time.unix_timestamp(),
has_parent: false,
};
let ui_font = ThemeSettings::get_global(cx).ui_font.clone();
// padding to avoid tooltip appearing right below the mouse cursor
// TODO: use tooltip_container here
Some(
div()
.pl_2()
.pt_2p5()
.child(
v_flex()
.elevation_2(cx)
.font(ui_font)
.text_ui(cx)
.text_color(cx.theme().colors().text)
.py_1()
.px_2()
.map(|el| {
el.occlude()
.on_mouse_move(|_, _, cx| cx.stop_propagation())
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
.child(
v_flex()
.w(gpui::rems(30.))
.gap_4()
.child(
h_flex()
.pb_1p5()
.gap_x_2()
.overflow_x_hidden()
.flex_wrap()
.children(avatar)
.child(author)
.when(!author_email.is_empty(), |this| {
this.child(
div()
.text_color(
cx.theme().colors().text_muted,
)
.child(author_email),
)
})
.border_b_1()
.border_color(cx.theme().colors().border_variant),
)
.child(
div()
.id("inline-blame-commit-message")
.child(message)
.max_h(message_max_height)
.overflow_y_scroll()
.track_scroll(&scroll_handle),
)
.child(
h_flex()
.text_color(cx.theme().colors().text_muted)
.w_full()
.justify_between()
.pt_1p5()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.child(absolute_timestamp)
.child(
h_flex()
.gap_1p5()
.when_some(pull_request, |this, pr| {
this.child(
Button::new(
"pull-request-button",
format!("#{}", pr.number),
)
.color(Color::Muted)
.icon(IconName::PullRequest)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.style(ButtonStyle::Subtle)
.on_click(move |_, _, cx| {
cx.stop_propagation();
cx.open_url(pr.url.as_str())
}),
)
})
.child(Divider::vertical())
.child(
Button::new(
"commit-sha-button",
short_commit_id.clone(),
)
.style(ButtonStyle::Subtle)
.color(Color::Muted)
.icon(IconName::FileGit)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.on_click(move |_, window, cx| {
CommitView::open(
commit_summary.clone(),
repository.downgrade(),
workspace.clone(),
window,
cx,
);
cx.stop_propagation();
}),
)
.child(
IconButton::new(
"copy-sha-button",
IconName::Copy,
)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| {
cx.stop_propagation();
cx.write_to_clipboard(
ClipboardItem::new_string(
full_sha.clone(),
),
)
}),
),
),
),
)
}),
)
.into_any_element(),
)
}
fn open_blame_commit(
&self,
blame_entry: BlameEntry,

View File

@@ -27,22 +27,18 @@ pub struct CommitDetails {
pub message: Option<ParsedCommitMessage>,
}
struct CommitAvatar<'a> {
pub struct CommitAvatar<'a> {
commit: &'a CommitDetails,
}
impl<'a> CommitAvatar<'a> {
fn new(details: &'a CommitDetails) -> Self {
pub fn new(details: &'a CommitDetails) -> Self {
Self { commit: details }
}
}
impl<'a> CommitAvatar<'a> {
fn render(
&'a self,
window: &mut Window,
cx: &mut Context<CommitTooltip>,
) -> Option<impl IntoElement + use<>> {
pub fn render(&'a self, window: &mut Window, cx: &mut App) -> Option<impl IntoElement + use<>> {
let remote = self
.commit
.message

View File

@@ -1,12 +1,8 @@
mod supported_countries;
use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
pub use supported_countries::*;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content(
@@ -338,7 +334,6 @@ pub struct CountTokensResponse {
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub raw_args: String,
pub args: serde_json::Value,
}

View File

@@ -1,232 +0,0 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by Google Gemini.
///
/// <https://ai.google.dev/gemini-api/docs/available-regions>
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by Google Gemini.
///
/// https://ai.google.dev/gemini-api/docs/available-regions
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"DZ", // Algeria
"AS", // American Samoa
"AO", // Angola
"AI", // Anguilla
"AQ", // Antarctica
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AW", // Aruba
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // The Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BM", // Bermuda
"BT", // Bhutan
"BO", // Bolivia
"BW", // Botswana
"BR", // Brazil
"IO", // British Indian Ocean Territory
"VG", // British Virgin Islands
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"BQ", // Caribbean Netherlands
"KY", // Cayman Islands
"CF", // Central African Republic
"TD", // Chad
"CL", // Chile
"CX", // Christmas Island
"CC", // Cocos (Keeling) Islands
"CO", // Colombia
"KM", // Comoros
"CK", // Cook Islands
"CI", // Côte d'Ivoire
"CR", // Costa Rica
"HR", // Croatia
"CW", // Curaçao
"CZ", // Czech Republic
"CD", // Democratic Republic of the Congo
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"ER", // Eritrea
"EE", // Estonia
"SZ", // Eswatini
"ET", // Ethiopia
"FK", // Falkland Islands (Islas Malvinas)
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GA", // Gabon
"GM", // The Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GI", // Gibraltar
"GR", // Greece
"GD", // Grenada
"GU", // Guam
"GT", // Guatemala
"GG", // Guernsey
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"HM", // Heard Island and McDonald Islands
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IM", // Isle of Man
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JE", // Jersey
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KG", // Kyrgyzstan
"KW", // Kuwait
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LY", // Libya
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"ML", // Mali
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MN", // Mongolia
"MS", // Montserrat
"MA", // Morocco
"MZ", // Mozambique
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NC", // New Caledonia
"NZ", // New Zealand
"NI", // Nicaragua
"NE", // Niger
"NG", // Nigeria
"NU", // Niue
"NF", // Norfolk Island
"MP", // Northern Mariana Islands
"NO", // Norway
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PN", // Pitcairn Islands
"PL", // Poland
"PT", // Portugal
"PR", // Puerto Rico
"QA", // Qatar
"CY", // Republic of Cyprus
"CG", // Republic of the Congo
"RO", // Romania
"RW", // Rwanda
"BL", // Saint Barthélemy
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"PM", // Saint Pierre and Miquelon
"VC", // Saint Vincent and the Grenadines
"SH", // Saint Helena, Ascension and Tristan da Cunha
"WS", // Samoa
"ST", // São Tomé and Príncipe
"SA", // Saudi Arabia
"SN", // Senegal
"SC", // Seychelles
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"SO", // Somalia
"ZA", // South Africa
"GS", // South Georgia and the South Sandwich Islands
"KR", // South Korea
"SS", // South Sudan
"ES", // Spain
"LK", // Sri Lanka
"SD", // Sudan
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste
"TG", // Togo
"TK", // Tokelau
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Türkiye
"TM", // Turkmenistan
"TC", // Turks and Caicos Islands
"TV", // Tuvalu
"UG", // Uganda
"GB", // United Kingdom
"AE", // United Arab Emirates
"US", // United States
"UM", // United States Minor Outlying Islands
"VI", // U.S. Virgin Islands
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VE", // Venezuela
"VN", // Vietnam
"WF", // Wallis and Futuna
"EH", // Western Sahara
"YE", // Yemen
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});

View File

@@ -1,7 +1,7 @@
use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
&self,
request: LanguageModelRequest,
_: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move {

View File

@@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage),
}
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
@@ -193,7 +206,7 @@ pub struct LanguageModelToolUse {
pub struct LanguageModelTextStream {
pub message_id: Option<String>,
pub stream: BoxStream<'static, Result<String>>,
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
// Has complete token usage after the stream has finished
pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
@@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
>;
fn stream_completion_with_usage(
&self,
@@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use collections::{HashSet, IndexMap};
use feature_flags::{Assistant2FeatureFlag, ZedPro};
use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
use gpui::{
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
@@ -584,7 +584,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.p_1()
.gap_4()
.justify_between()
.when(cx.has_flag::<ZedPro>(), |this| {
.when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
this.child(match plan {
Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
.icon(IconName::ZedAssistant)

View File

@@ -71,7 +71,7 @@ fn register_language_model_providers(
);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
cx.observe_flag::<feature_flags::LanguageModelsFeatureFlag, _>(move |enabled, cx| {
let user_store = user_store.clone();
let client = client.clone();
LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {

View File

@@ -12,10 +12,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -27,7 +27,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe};
use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic";
@@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_anthropic(
request,
self.model.request_id().into(),
@@ -626,7 +631,7 @@ pub fn into_anthropic(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events(
Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
let input_json = tool_use.input_json.trim();
let input_value = if input_json.is_empty() {
Ok(serde_json::Value::Object(serde_json::Map::default()))
} else {
serde_json::Value::from_str(input_json)
};
let event_result = match input_value {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
},
)),
Err(json_parse_err) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
return Some((
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input: if input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
)
} else {
serde_json::Value::from_str(
input_json
)
.map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
},
raw_input: tool_use.input_json.clone(),
},
))
})],
state,
));
return Some((vec![event_result], state));
}
}
Event::MessageStart { message } => {
@@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events(
}
Event::Error { error } => {
return Some((
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
vec![Err(LanguageModelCompletionError::Other(anyhow!(
AnthropicError::ApiError(error)
)))],
state,
));
}
_ => {}
},
Err(err) => {
return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
}
}
}

View File

@@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
// Get region - from credentials or directly from settings
let region = state
@@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events(
_ => {}
},
Err(err) => return Some((Some(Err(anyhow!(err))), state)),
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
}
}
None

View File

@@ -2,7 +2,7 @@ use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
stream::BoxStream,
@@ -10,11 +10,11 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
ZED_CLOUD_PROVIDER_ID,
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -324,7 +324,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
);
}
let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
zed_cloud_provider_additional_models()
} else {
&[]
@@ -531,7 +531,7 @@ impl CloudLanguageModel {
{
request_builder.uri(completions_url)
} else {
request_builder.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
};
let request = request_builder
.header("Content-Type", "application/json")
@@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
self.stream_completion_with_usage(request, cx)
.map(|result| result.map(|(stream, _)| stream))
.boxed()
@@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {
@@ -940,7 +945,7 @@ impl Render for ConfigurationView {
),
),
)
} else if cx.has_flag::<ZedPro>() {
} else if cx.has_flag::<ZedProFeatureFlag>() {
Some(
h_flex()
.gap_2()

View File

@@ -17,16 +17,16 @@ use gpui::{
Transformation, percentage, svg,
};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason,
};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::prelude::*;
use util::maybe;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
@@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
@@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
vec![Err(anyhow!("Response contained no choices"))],
vec![Err(anyhow!("Response contained no choices").into())],
state,
));
};
@@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
let Some(delta) = delta else {
return Some((
vec![Err(anyhow!("Response contained no delta"))],
vec![Err(anyhow!("Response contained no delta").into())],
state,
));
};
@@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| {
maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
raw_input: tool_call.arguments.clone(),
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
},
))
})
|(_, tool_call)| match serde_json::Value::from_str(
&tool_call.arguments,
) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
},
));
@@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => return Some((vec![Err(err)], state)),
Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
}
}
@@ -453,9 +464,11 @@ impl CopilotChatLanguageModel {
}
}
messages.push(ChatMessage::User {
content: text_content,
});
if !text_content.is_empty() {
messages.push(ChatMessage::User {
content: text_content,
});
}
}
Role::Assistant => {
let mut tool_calls = Vec::new();

View File

@@ -9,9 +9,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_deepseek(
request,
self.model.id().to_string(),
@@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
result.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}

View File

@@ -11,8 +11,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> {
let request = into_google(request, self.model.id().to_string());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(map_to_language_model_completion_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -396,7 +404,6 @@ pub fn into_google(
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
function_call: google_ai::FunctionCall {
name: tool_use.name.to_string(),
raw_args: tool_use.raw_input,
args: tool_use.input,
},
}))
@@ -472,7 +479,7 @@ pub fn into_google(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
@@ -493,7 +500,7 @@ pub fn map_to_language_model_completion_events(
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata);
@@ -540,8 +547,8 @@ pub fn map_to_language_model_completion_events(
is_input_complete: true,
raw_input: function_call_part
.function_call
.raw_args
.clone(),
.args
.to_string(),
input: function_call_part.function_call.args,
},
)));
@@ -560,7 +567,10 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => {
return Some((vec![Err(anyhow!(err))], state));
return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
}
}
}

View File

@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone();
@@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()

View File

@@ -8,9 +8,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use futures::stream::BoxStream;
@@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_mistral(
request,
self.model.id().to_string(),
@@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
result.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}

View File

@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()

View File

@@ -9,10 +9,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
};
use open_ai::{Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
@@ -24,7 +24,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe};
use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
@@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
@@ -419,7 +424,7 @@ pub fn into_open_ai(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
vec![Err(anyhow!("Response contained no choices"))],
vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices"
)))],
state,
));
};
@@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| {
maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
raw_input: tool_call.arguments.clone(),
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
},
))
})
|(_, tool_call)| match serde_json::Value::from_str(
&tool_call.arguments,
) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
},
));
@@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => return Some((vec![Err(err)], state)),
Err(err) => {
return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
}
}
}

View File

@@ -84,7 +84,7 @@ fn get_max_tokens(name: &str) -> usize {
"mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "qwen2.5-coder"
| "dolphin-mixtral" => 32768,
"llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
| "deepseek-coder-v2" | "deepseek-r1" | "yi-coder" => 128000,
| "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" => 128000,
_ => DEFAULT_TOKENS,
}
.clamp(1, MAXIMUM_TOKENS)

View File

@@ -1,5 +1,3 @@
mod supported_countries;
use anyhow::{Context as _, Result, anyhow};
use futures::{
AsyncBufReadExt, AsyncReadExt, StreamExt,
@@ -15,8 +13,6 @@ use std::{
};
use strum::EnumIter;
pub use supported_countries::*;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {

View File

@@ -1,238 +0,0 @@
use std::collections::HashSet;
use std::sync::LazyLock;
/// Returns whether the given country code is supported by OpenAI.
///
/// <https://platform.openai.com/docs/supported-countries>
pub fn is_supported_country(country_code: &str) -> bool {
SUPPORTED_COUNTRIES.contains(&country_code)
}
/// The list of country codes supported by OpenAI.
///
/// https://platform.openai.com/docs/supported-countries
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
vec![
"AL", // Albania
"DZ", // Algeria
"AS", // American Samoa (US)
"AI", // Anguilla (UK)
"AF", // Afghanistan
"AD", // Andorra
"AO", // Angola
"AG", // Antigua and Barbuda
"AR", // Argentina
"AM", // Armenia
"AU", // Australia
"AT", // Austria
"AZ", // Azerbaijan
"BS", // Bahamas
"BH", // Bahrain
"BD", // Bangladesh
"BB", // Barbados
"BE", // Belgium
"BZ", // Belize
"BJ", // Benin
"BM", // Bermuda (UK)
"BT", // Bhutan
"BO", // Bolivia
"BA", // Bosnia and Herzegovina
"BW", // Botswana
"BR", // Brazil
"IO", // British Indian Ocean Territory (UK)
"BN", // Brunei
"BG", // Bulgaria
"BF", // Burkina Faso
"BI", // Burundi
"CV", // Cabo Verde
"KH", // Cambodia
"CM", // Cameroon
"CA", // Canada
"KY", // Cayman Islands (UK)
"CF", // Central African Republic
"TD", // Chad
"CL", // Chile
"CX", // Christmas Island (AU)
"CC", // Cocos (Keeling) Islands (AU)
"CO", // Colombia
"KM", // Comoros
"CG", // Congo (Brazzaville)
"CD", // Congo (DRC)
"CK", // Cook Islands (NZ)
"CR", // Costa Rica
"CI", // Côte d'Ivoire
"HR", // Croatia
"CY", // Cyprus
"CZ", // Czechia (Czech Republic)
"DK", // Denmark
"DJ", // Djibouti
"DM", // Dominica
"DO", // Dominican Republic
"EC", // Ecuador
"EG", // Egypt
"SV", // El Salvador
"GQ", // Equatorial Guinea
"ER", // Eritrea
"EE", // Estonia
"SZ", // Eswatini (Swaziland)
"ET", // Ethiopia
"FK", // Falkland Islands (UK)
"FJ", // Fiji
"FI", // Finland
"FR", // France
"GF", // French Guiana (FR)
"PF", // French Polynesia (FR)
"TF", // French Southern Territories
"GA", // Gabon
"GM", // Gambia
"GE", // Georgia
"DE", // Germany
"GH", // Ghana
"GI", // Gibraltar (UK)
"GR", // Greece
"GD", // Grenada
"GT", // Guatemala
"GU", // Guam (US)
"GN", // Guinea
"GW", // Guinea-Bissau
"GY", // Guyana
"HT", // Haiti
"HM", // Heard Island and McDonald Islands (AU)
"VA", // Holy See (Vatican City)
"HN", // Honduras
"HU", // Hungary
"IS", // Iceland
"IN", // India
"ID", // Indonesia
"IQ", // Iraq
"IE", // Ireland
"IL", // Israel
"IT", // Italy
"JM", // Jamaica
"JP", // Japan
"JO", // Jordan
"KZ", // Kazakhstan
"KE", // Kenya
"KI", // Kiribati
"KW", // Kuwait
"KG", // Kyrgyzstan
"LA", // Laos
"LV", // Latvia
"LB", // Lebanon
"LS", // Lesotho
"LR", // Liberia
"LY", // Libya
"LI", // Liechtenstein
"LT", // Lithuania
"LU", // Luxembourg
"MG", // Madagascar
"MW", // Malawi
"MY", // Malaysia
"MV", // Maldives
"ML", // Mali
"MT", // Malta
"MH", // Marshall Islands
"MR", // Mauritania
"MU", // Mauritius
"MX", // Mexico
"FM", // Micronesia
"MD", // Moldova
"MC", // Monaco
"MN", // Mongolia
"MS", // Montserrat (UK)
"ME", // Montenegro
"MA", // Morocco
"MZ", // Mozambique
"MM", // Myanmar
"NA", // Namibia
"NR", // Nauru
"NP", // Nepal
"NL", // Netherlands
"NZ", // New Zealand
"NI", // Nicaragua
"NE", // Niger
"NG", // Nigeria
"NF", // Norfolk Island (AU)
"MK", // North Macedonia
"MI", // Northern Mariana Islands (UK)
"NO", // Norway
"NU", // Niue (NZ)
"OM", // Oman
"PK", // Pakistan
"PW", // Palau
"PS", // Palestine
"PA", // Panama
"PG", // Papua New Guinea
"PY", // Paraguay
"PE", // Peru
"PH", // Philippines
"PN", // Pitcairn (UK)
"PL", // Poland
"PT", // Portugal
"PR", // Puerto Rico (US)
"QA", // Qatar
"RO", // Romania
"RW", // Rwanda
"BL", // Saint Barthélemy (FR)
"KN", // Saint Kitts and Nevis
"LC", // Saint Lucia
"MF", // Saint Martin (FR)
"PM", // Saint Pierre and Miquelon (FR)
"VC", // Saint Vincent and the Grenadines
"WS", // Samoa
"SM", // San Marino
"ST", // Sao Tome and Principe
"SA", // Saudi Arabia
"SN", // Senegal
"RS", // Serbia
"SC", // Seychelles
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
"SL", // Sierra Leone
"SG", // Singapore
"SK", // Slovakia
"SI", // Slovenia
"SB", // Solomon Islands
"SO", // Somalia
"ZA", // South Africa
"KR", // South Korea
"SS", // South Sudan
"ES", // Spain
"LK", // Sri Lanka
"SR", // Suriname
"SE", // Sweden
"CH", // Switzerland
"SD", // Sudan
"TW", // Taiwan
"TJ", // Tajikistan
"TZ", // Tanzania
"TH", // Thailand
"TL", // Timor-Leste (East Timor)
"TG", // Togo
"TK", // Tokelau (NZ)
"TO", // Tonga
"TT", // Trinidad and Tobago
"TN", // Tunisia
"TR", // Turkey
"TM", // Turkmenistan
"TC", // Turks and Caicos Islands (UK)
"TV", // Tuvalu
"UG", // Uganda
"UA", // Ukraine (with certain exceptions)
"AE", // United Arab Emirates
"GB", // United Kingdom
"UM", // United States Minor Outlying Islands (US)
"US", // United States of America
"UY", // Uruguay
"UZ", // Uzbekistan
"VU", // Vanuatu
"VN", // Vietnam
"VI", // Virgin Islands (US)
"VG", // Virgin Islands (UK)
"WF", // Wallis and Futuna (FR)
"YE", // Yemen
"ZM", // Zambia
"ZW", // Zimbabwe
]
.into_iter()
.collect()
});

View File

@@ -31,14 +31,6 @@ pub enum TerminalKind {
Shell(Option<PathBuf>),
/// Run a task.
Task(SpawnInTerminal),
/// Run a debug terminal.
Debug {
command: Option<String>,
args: Vec<String>,
envs: HashMap<String, String>,
cwd: Option<Arc<Path>>,
title: Option<String>,
},
}
/// SshCommand describes how to connect to a remote server
@@ -105,7 +97,6 @@ impl Project {
self.active_project_directory(cx)
}
}
TerminalKind::Debug { cwd, .. } => cwd.clone(),
};
let mut settings_location = None;
@@ -209,7 +200,6 @@ impl Project {
this.active_project_directory(cx)
}
}
TerminalKind::Debug { cwd, .. } => cwd.clone(),
};
let ssh_details = this.ssh_details(cx);
@@ -243,7 +233,6 @@ impl Project {
};
let mut python_venv_activate_command = None;
let debug_terminal = matches!(kind, TerminalKind::Debug { .. });
let (spawn_task, shell) = match kind {
TerminalKind::Shell(_) => {
@@ -339,27 +328,6 @@ impl Project {
}
}
}
TerminalKind::Debug {
command,
args,
envs,
title,
..
} => {
env.extend(envs);
let shell = if let Some(program) = command {
Shell::WithArguments {
program,
args,
title_override: Some(title.unwrap_or("Debug Terminal".into()).into()),
}
} else {
settings.shell.clone()
};
(None, shell)
}
};
TerminalBuilder::new(
local_path.map(|path| path.to_path_buf()),
@@ -373,7 +341,6 @@ impl Project {
ssh_details.is_some(),
window,
completion_tx,
debug_terminal,
cx,
)
.map(|builder| {

View File

@@ -60,9 +60,7 @@ pub enum PromptId {
impl PromptId {
pub fn new() -> PromptId {
PromptId::User {
uuid: UserPromptId::new(),
}
UserPromptId::new().into()
}
pub fn is_built_in(&self) -> bool {
@@ -70,6 +68,12 @@ impl PromptId {
}
}
impl From<UserPromptId> for PromptId {
fn from(uuid: UserPromptId) -> Self {
PromptId::User { uuid }
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct UserPromptId(pub Uuid);
@@ -227,9 +231,7 @@ impl PromptStore {
.collect::<heed::Result<HashMap<_, _>>>()?;
for (prompt_id_v1, metadata_v1) in metadata_v1 {
let prompt_id_v2 = PromptId::User {
uuid: UserPromptId(prompt_id_v1.0),
};
let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
continue;
};

View File

@@ -1,5 +1,5 @@
[package]
name = "prompt_library"
name = "rules_library"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/prompt_library.rs"
path = "src/rules_library.rs"
[dependencies]
anyhow.workspace = true

View File

@@ -352,7 +352,6 @@ impl TerminalBuilder {
is_ssh_terminal: bool,
window: AnyWindowHandle,
completion_tx: Sender<Option<ExitStatus>>,
debug_terminal: bool,
cx: &App,
) -> Result<TerminalBuilder> {
// If the parent environment doesn't have a locale set
@@ -502,7 +501,6 @@ impl TerminalBuilder {
word_regex: RegexSearch::new(WORD_REGEX).unwrap(),
python_file_line_regex: RegexSearch::new(PYTHON_FILE_LINE_REGEX).unwrap(),
vi_mode_enabled: false,
debug_terminal,
is_ssh_terminal,
python_venv_directory,
};
@@ -660,7 +658,6 @@ pub struct Terminal {
python_file_line_regex: RegexSearch,
task: Option<TaskState>,
vi_mode_enabled: bool,
debug_terminal: bool,
is_ssh_terminal: bool,
}
@@ -1855,10 +1852,6 @@ impl Terminal {
self.task.as_ref()
}
pub fn debug_terminal(&self) -> bool {
self.debug_terminal
}
pub fn wait_for_completed_task(&self, cx: &App) -> Task<Option<ExitStatus>> {
if let Some(task) = self.task() {
if task.status == TaskStatus::Running {

View File

@@ -1420,9 +1420,6 @@ impl Item for TerminalView {
}
}
},
None if self.terminal.read(cx).debug_terminal() => {
(IconName::Debug, Color::Muted, None)
}
None => (IconName::Terminal, Color::Muted, None),
};
@@ -1583,7 +1580,7 @@ impl SerializableItem for TerminalView {
cx: &mut Context<Self>,
) -> Option<Task<gpui::Result<()>>> {
let terminal = self.terminal().read(cx);
if terminal.task().is_some() || terminal.debug_terminal() {
if terminal.task().is_some() {
return None;
}

View File

@@ -18,7 +18,7 @@ use crate::platforms::{platform_linux, platform_mac, platform_windows};
use auto_update::AutoUpdateStatus;
use call::ActiveCall;
use client::{Client, UserStore};
use feature_flags::{FeatureFlagAppExt, ZedPro};
use feature_flags::{FeatureFlagAppExt, ZedProFeatureFlag};
use gpui::{
Action, AnyElement, App, Context, Corner, Decorations, Element, Entity, InteractiveElement,
Interactivity, IntoElement, MouseButton, ParentElement, Render, Stateful,
@@ -663,7 +663,7 @@ impl TitleBar {
.anchor(Corner::TopRight)
.menu(move |window, cx| {
ContextMenu::build(window, cx, |menu, _, cx| {
menu.when(cx.has_flag::<ZedPro>(), |menu| {
menu.when(cx.has_flag::<ZedProFeatureFlag>(), |menu| {
menu.action(
format!(
"Current Plan: {}",

View File

@@ -142,6 +142,10 @@ impl PaneGroup {
self.root.first_pane()
}
pub fn last_pane(&self) -> Entity<Pane> {
self.root.last_pane()
}
pub fn find_pane_in_direction(
&mut self,
active_pane: &Entity<Pane>,
@@ -360,6 +364,13 @@ impl Member {
}
}
fn last_pane(&self) -> Entity<Pane> {
match self {
Member::Axis(axis) => axis.members.last().unwrap().last_pane(),
Member::Pane(pane) => pane.clone(),
}
}
pub fn render(
&self,
basis: usize,

View File

@@ -20,7 +20,7 @@ use command_palette_hooks::CommandPaletteFilter;
use debugger_ui::debugger_panel::DebugPanel;
use editor::ProposedChangesEditorToolbar;
use editor::{Editor, MultiBuffer, scroll::Autoscroll};
use feature_flags::{Debugger, FeatureFlagAppExt, FeatureFlagViewExt};
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt, FeatureFlagViewExt};
use futures::{StreamExt, channel::mpsc, select_biased};
use git_ui::git_panel::GitPanel;
use git_ui::project_diff::ProjectDiffToolbar;
@@ -279,7 +279,7 @@ fn feature_gate_zed_pro_actions(cx: &mut App) {
filter.hide_action_types(&zed_pro_actions);
});
cx.observe_flag::<feature_flags::ZedPro, _>({
cx.observe_flag::<feature_flags::ZedProFeatureFlag, _>({
move |is_enabled, cx| {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
if is_enabled {
@@ -439,7 +439,7 @@ fn initialize_panels(
workspace.add_panel(channels_panel, window, cx);
workspace.add_panel(chat_panel, window, cx);
workspace.add_panel(notification_panel, window, cx);
cx.when_flag_enabled::<Debugger>(window, |_, window, cx| {
cx.when_flag_enabled::<DebuggerFeatureFlag>(window, |_, window, cx| {
cx.spawn_in(
window,
async move |workspace: gpui::WeakEntity<Workspace>,

View File

@@ -199,14 +199,14 @@ pub mod assistant {
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct OpenPromptLibrary {
pub struct OpenRulesLibrary {
#[serde(skip)]
pub prompt_to_select: Option<Uuid>,
}
impl_action_with_deprecated_aliases!(
assistant,
OpenPromptLibrary,
OpenRulesLibrary,
["assistant::DeployPromptLibrary"]
);

View File

@@ -2178,7 +2178,7 @@ Examples:
## Show Whitespaces
- Description: Whether or not to show render whitespace characters in the editor.
- Description: Whether or not to render whitespace characters in the editor.
- Setting: `show_whitespaces`
- Default: `selection`

Some files were not shown because too many files have changed in this diff Show More