Compare commits

..

8 Commits

Author SHA1 Message Date
Richard Feldman
d42ce71772 Retry if language server hasn't started up yet. 2025-04-17 11:09:03 -04:00
Richard Feldman
82511d4300 Print failrues and successes 2025-04-17 11:05:25 -04:00
Antonio Scandurra
46a7cd93d9 Load grammars in the eval
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-04-16 13:43:11 -07:00
Antonio Scandurra
d253889fe3 Checkpoint
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-04-16 13:24:43 -07:00
Antonio Scandurra
29149c2eb5 Revamp tools and system prompt
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-04-16 12:40:58 -07:00
Antonio Scandurra
853f47b9a2 Checkpoint 2025-04-16 10:42:36 -07:00
Nathan Sobo
86dbbdc921 Match full paths against glob patterns in path search tool
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-04-15 21:33:56 -07:00
Nathan Sobo
d78cf50efb Output repository diff in eval example log
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-04-15 21:33:32 -07:00
102 changed files with 1423 additions and 3627 deletions

120
Cargo.lock generated
View File

@@ -324,7 +324,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.26.3",
"strum",
"thiserror 2.0.12",
"workspace-hack",
]
@@ -567,7 +567,7 @@ dependencies = [
"settings",
"smallvec",
"smol",
"strum 0.26.3",
"strum",
"telemetry_events",
"text",
"theme",
@@ -704,7 +704,6 @@ dependencies = [
"assistant_tool",
"chrono",
"collections",
"feature_flags",
"futures 0.3.31",
"gpui",
"html_to_markdown",
@@ -719,14 +718,13 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"settings",
"ui",
"unindent",
"util",
"web_search",
"workspace",
"workspace-hack",
"worktree",
"zed_llm_client",
]
[[package]]
@@ -1884,7 +1882,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.26.3",
"strum",
"thiserror 2.0.12",
"tokio",
"workspace-hack",
@@ -3031,7 +3029,7 @@ dependencies = [
"settings",
"sha2",
"sqlx",
"strum 0.26.3",
"strum",
"subtle",
"supermaven_api",
"telemetry_events",
@@ -3363,7 +3361,7 @@ dependencies = [
"serde",
"serde_json",
"settings",
"strum 0.26.3",
"strum",
"task",
"theme",
"ui",
@@ -4480,7 +4478,7 @@ dependencies = [
"optfield",
"proc-macro2",
"quote",
"strum 0.26.3",
"strum",
"syn 2.0.100",
]
@@ -4889,7 +4887,6 @@ dependencies = [
"collections",
"context_server",
"dap",
"dirs 5.0.1",
"env_logger 0.11.8",
"extension",
"fs",
@@ -4909,13 +4906,12 @@ dependencies = [
"release_channel",
"reqwest_client",
"serde",
"serde_json",
"settings",
"shellexpand 2.1.2",
"telemetry",
"toml 0.8.20",
"unindent",
"util",
"uuid",
"workspace-hack",
]
@@ -5125,7 +5121,7 @@ dependencies = [
"serde",
"settings",
"smallvec",
"strum 0.26.3",
"strum",
"telemetry",
"theme",
"ui",
@@ -5976,7 +5972,7 @@ dependencies = [
"serde_derive",
"serde_json",
"settings",
"strum 0.26.3",
"strum",
"telemetry",
"theme",
"time",
@@ -6069,7 +6065,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.26.3",
"strum",
"workspace-hack",
]
@@ -6175,7 +6171,7 @@ dependencies = [
"slotmap",
"smallvec",
"smol",
"strum 0.26.3",
"strum",
"sum_tree",
"taffy",
"thiserror 2.0.12",
@@ -6823,7 +6819,7 @@ name = "icons"
version = "0.1.0"
dependencies = [
"serde",
"strum 0.26.3",
"strum",
"workspace-hack",
]
@@ -7091,7 +7087,7 @@ dependencies = [
"paths",
"pretty_assertions",
"serde",
"strum 0.26.3",
"strum",
"util",
"workspace-hack",
]
@@ -7677,7 +7673,7 @@ dependencies = [
"serde",
"serde_json",
"smol",
"strum 0.26.3",
"strum",
"telemetry_events",
"thiserror 2.0.12",
"util",
@@ -7737,7 +7733,7 @@ dependencies = [
"serde_json",
"settings",
"smol",
"strum 0.26.3",
"strum",
"theme",
"thiserror 2.0.12",
"tiktoken-rs",
@@ -7745,7 +7741,6 @@ dependencies = [
"ui",
"util",
"workspace-hack",
"zed_llm_client",
]
[[package]]
@@ -8710,7 +8705,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.26.3",
"strum",
"workspace-hack",
]
@@ -9557,7 +9552,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.26.3",
"strum",
"workspace-hack",
]
@@ -12136,7 +12131,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"strum 0.26.3",
"strum",
"tracing",
"util",
"workspace-hack",
@@ -12664,7 +12659,7 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
"strum 0.26.3",
"strum",
"thiserror 2.0.12",
"time",
"tracing",
@@ -13709,7 +13704,7 @@ dependencies = [
"settings",
"simplelog",
"story",
"strum 0.26.3",
"strum",
"theme",
"title_bar",
"ui",
@@ -13791,16 +13786,7 @@ version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros 0.26.4",
]
[[package]]
name = "strum"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
dependencies = [
"strum_macros 0.27.1",
"strum_macros",
]
[[package]]
@@ -13816,19 +13802,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "strum_macros"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.100",
]
[[package]]
name = "subtle"
version = "2.6.1"
@@ -14444,7 +14417,7 @@ dependencies = [
"serde_json_lenient",
"serde_repr",
"settings",
"strum 0.26.3",
"strum",
"thiserror 2.0.12",
"util",
"uuid",
@@ -14478,7 +14451,7 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"simplelog",
"strum 0.26.3",
"strum",
"theme",
"vscode_theme",
"workspace-hack",
@@ -15479,7 +15452,7 @@ dependencies = [
"settings",
"smallvec",
"story",
"strum 0.26.3",
"strum",
"theme",
"ui_macros",
"util",
@@ -16612,36 +16585,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web_search"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"gpui",
"serde",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "web_search_providers"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",
"language_model",
"serde",
"serde_json",
"web_search",
"workspace-hack",
"zed_llm_client",
]
[[package]]
name = "webpki-root-certs"
version = "0.26.8"
@@ -17680,7 +17623,7 @@ dependencies = [
"settings",
"smallvec",
"sqlez",
"strum 0.26.3",
"strum",
"task",
"telemetry",
"tempfile",
@@ -17825,7 +17768,7 @@ dependencies = [
"sqlx-macros-core",
"sqlx-postgres",
"sqlx-sqlite",
"strum 0.26.3",
"strum",
"subtle",
"syn 1.0.109",
"syn 2.0.100",
@@ -18197,7 +18140,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.183.3"
version = "0.183.0"
dependencies = [
"activity_indicator",
"agent",
@@ -18320,8 +18263,6 @@ dependencies = [
"uuid",
"vim",
"vim_mode_setting",
"web_search",
"web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
@@ -18386,13 +18327,12 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.5.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
checksum = "1bf21350eced858d129840589158a8f6895c4fa4327ae56dd8c7d6a98495bed4"
dependencies = [
"serde",
"serde_json",
"strum 0.27.1",
"uuid",
]

View File

@@ -165,8 +165,6 @@ members = [
"crates/util_macros",
"crates/vim",
"crates/vim_mode_setting",
"crates/web_search",
"crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
@@ -372,8 +370,6 @@ util = { path = "crates/util" }
util_macros = { path = "crates/util_macros" }
vim = { path = "crates/vim" }
vim_mode_setting = { path = "crates/vim_mode_setting" }
web_search = { path = "crates/web_search" }
web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
@@ -605,7 +601,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.5.0"
zed_llm_client = "0.4"
zstd = "0.11"
metal = "0.29"

View File

@@ -630,7 +630,6 @@
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-alt-p": "assistant::OpenPromptLibrary",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",

View File

@@ -286,7 +286,6 @@
"cmd-alt-n": "agent::NewTextThread",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-alt-p": "assistant::OpenPromptLibrary",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",

View File

@@ -1,148 +1,65 @@
You are an AI assistant integrated into a code editor. You have the programming ability of an expert programmer who takes pride in writing high-quality code and is driven to the point of obsession about solving problems effectively. Your goal is to do one of the following two things:
You are a powerful agentic AI coding assistant. You operate exclusively in Zed, the world's best IDE.
1. Help users answer questions and perform tasks related to their codebase.
2. Answer general-purpose questions unrelated to their particular codebase.
You are pair programming with a USER to solve their coding task.
The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.
Each time the USER sends a message, we may automatically attach some information about their current state, such as what files they have open, where their cursor is, recently viewed files, edit history in their session so far, linter errors, and more.
This information may or may not be relevant to the coding task, it is up for you to decide.
Your main goal is to follow the USER's instructions at each message.
It will be up to you to decide which of these you are doing based on what the user has told you. When unclear, ask clarifying questions to understand the user's intent before proceeding.
<communication>
1. Be conversational but professional.
2. Refer to the USER in the second person and yourself in the first person.
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. Use \( and \) for inline math, \[ and \] for block math.
4. NEVER lie or make things up.
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
</communication>
You should only perform actions that modify the user's system if explicitly requested by the user:
- If the user asks a question about how to accomplish a task, provide guidance or information, and use read-only tools (e.g., search) to assist. You may suggest potential actions, but do not directly modify the user's system without explicit instruction.
- If the user clearly requests that you perform an action, carry out the action directly without explaining why you are doing so.
<tool_calling>
You have tools at your disposal to solve the coding task. Follow these rules regarding tool calls:
1. ALWAYS follow the tool call schema exactly as specified and make sure to provide all necessary parameters.
2. The conversation may reference tools that are no longer available. NEVER call tools that are not explicitly provided.
3. **NEVER refer to tool names when speaking to the USER.** For example, instead of saying 'I need to use the edit_file tool to edit your file', just say 'I will edit your file'.
4. Only calls tools when they are necessary. If the USER's task is general or you already know the answer, just respond without calling tools.
5. Before calling each tool, first explain to the USER why you are calling it.
</tool_calling>
When answering questions, it's okay to give incomplete examples containing comments about what would go there in a real version. When being asked to directly perform tasks on the code base, you must ALWAYS make fully working code. You may never "simplify" the code by omitting or deleting functionality you know the user has requested, and you must NEVER write comments like "in a full version, this would..." - instead, you must actually implement the real version. Don't be lazy!
<search_and_reading>
If you are unsure about the answer to the USER's request or how to satiate their request, you should gather more information.
This can be done with additional tool calls, asking clarifying questions, etc...
Note that project files are automatically backed up. The user can always get them back later if anything goes wrong, so there's
no need to create backup files (e.g. `.bak` files) because these files will just take up unnecessary space on the user's disk.
For example, if you've performed a semantic search, and the results may not fully answer the USER's request, or merit gathering more information, feel free to call more tools.
Similarly, if you've performed an edit that may partially satiate the USER's query, but you're not confident, gather more information or use more tools
before ending your turn.
When attempting to resolve issues around failing tests, never simply remove the failing tests. Unless the user explicitly asks you to remove tests, ALWAYS attempt to fix the code causing the tests to fail.
Bias towards not asking the user for help if you can find the answer yourself.
</search_and_reading>
Ignore "TODO"-type comments unless they're relevant to the user's explicit request or the user specifically asks you to address them. It is, however, okay to include them in codebase summaries.
<making_code_changes>
When making code changes, NEVER output code to the USER, unless requested. Instead use one of the code edit tools to implement the change.
Use the code edit tools at most once per turn.
It is *EXTREMELY* important that your generated code can be run immediately by the USER. To ensure this, follow these instructions carefully:
1. Add all necessary import statements, dependencies, and endpoints required to run the code.
2. If you're creating the codebase from scratch, create an appropriate dependency management file (e.g. requirements.txt) with package versions and a helpful README.
3. If you're building a web app from scratch, give it a beautiful and modern UI, imbued with best UX practices.
4. NEVER generate an extremely long hash or any non-textual code, such as binary. These are not helpful to the USER and are very expensive.
5. Unless you are appending some small easy to apply edit to a file, or creating a new file, you MUST read the the contents or section of what you're editing before editing it.
6. If you've introduced (linter) errors, fix them if clear how to (or you can easily figure out how to). Do not make uneducated guesses. And DO NOT loop more than 3 times on fixing linter errors on the same file. On the third time, you should stop and ask the user what to do next.
7. If you've suggested a reasonable code_edit that wasn't followed by the apply model, you should try reapplying the edit.
</making_code_changes>
<style>
Editing code:
- Make sure to take previous edits into account.
- The edits you perform might lead to errors or warnings. At the end of your changes, check whether you introduced any problems, and fix them before providing a summary of the changes you made.
- You may only attempt to fix these up to 3 times. If you have tried 3 times to fix them, and there are still problems remaining, you must not continue trying to fix them, and must instead tell the user that there are problems remaining - and ask if the user would like you to attempt to solve them further.
- Do not fix errors unrelated to your changes unless the user explicitly asks you to do so.
- Prefer to move files over recreating them. The move can be followed by minor edits if required.
- If you seem to be stuck, never go back and "simplify the implementation" by deleting the parts of the implementation you're stuck on and replacing them with comments. If you ever feel the urge to do this, instead immediately stop whatever you're doing (even if the code is in a broken state), report that you are stuck, explain what you're stuck on, and ask the user how to proceed.
<debugging>
When debugging, only make code changes if you are certain that you can solve the problem.
Otherwise, follow debugging best practices:
1. Address the root cause instead of the symptoms.
2. Add descriptive logging statements and error messages to track variable and code state.
3. Add test functions and statements to isolate the problem.
</debugging>
Tool use:
- Make sure to adhere to the tools schema.
- Provide every required argument.
- DO NOT use tools to access items that are already available in the context section.
- Use only the tools that are currently available.
- DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
Responding:
- Be concise and direct in your responses.
- Never apologize or thank the user.
- Don't comment that you have just realized or understood something.
- When you are going to make a tool call, tersely explain your reasoning for choosing to use that tool, with no flourishes or commentary beyond that information.
For example, rather than saying "You're absolutely right! Thank you for providing that context. Now I understand that we're missing a dependency, and I need to add it:" say "I'll add that missing dependency:" instead.
- Also, don't restate what a tool call is about to do (or just did).
For example, don't say "Now I'm going to check diagnostics to see if there are any warnings or errors," followed by running a tool which checks diagnostics and reports warnings or errors; instead, just request the tool call without saying anything.
- All tool results are provided to you automatically, so DO NOT thank the user when this happens.
Whenever you mention a code block, you MUST use ONLY the following format:
```language path/to/Something.blah#L123-456
(code goes here)
```
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
is a path in the project. (If there is no valid path in the project, then you can use
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
does not understand the more common ```language syntax, or bare ``` blocks. It only
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
You have made a mistake. You can only ever put paths after triple backticks!
<example>
Based on all the information I've gathered, here's a summary of how this system works:
1. The README file is loaded into the system.
2. The system finds the first two headers, including everything in between. In this case, that would be:
```path/to/README.md#L8-12
# First Header
This is the info under the first header.
## Sub-header
```
3. Then the system finds the last header in the README:
```path/to/README.md#L27-29
## Last Header
This is the last header in the README.
```
4. Finally, it passes this information on to the next process.
</example>
<example>
In Markdown, hash marks signify headings. For example:
```/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</example>
Here are examples of ways you must never render code blocks:
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it does not include the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it has the language instead of the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
# Level 1 heading
## Level 2 heading
### Level 3 heading
</bad_example_do_not_do_this>
This example is unacceptable because it uses indentation to mark the code block
instead of backticks with a path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
</style>
<calling_external_apis>
1. Unless explicitly requested by the USER, use the best suited external APIs and packages to solve the task. There is no need to ask the USER for permission.
2. When selecting which version of an API or package to use, choose one that is compatible with the USER's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
3. If an external API requires an API Key, be sure to point this out to the USER. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
</calling_external_apis>
The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files:

View File

@@ -652,35 +652,33 @@
"path_search": true,
"read_file": true,
"regex_search": true,
"thinking": true,
"web_search": true
"thinking": true
}
},
"write": {
"name": "Write",
"enable_all_context_servers": true,
"tools": {
"terminal": true,
"batch_tool": true,
"code_actions": true,
"code_symbols": true,
"contents": true,
"batch_tool": false,
"code_actions": false,
"code_symbols": false,
"contents": false,
"copy_path": false,
"create_file": true,
"delete_path": false,
"diagnostics": true,
"find_replace_file": true,
"edit_file": true,
"fetch": true,
"list_directory": false,
"list_directory": true,
"move_path": false,
"now": true,
"now": false,
"path_search": true,
"read_file": true,
"regex_search": true,
"rename": true,
"symbol_info": true,
"thinking": true,
"web_search": true
"rename": false,
"symbol_info": false,
"terminal": true,
"thinking": true
}
}
},

View File

@@ -1,30 +1,27 @@
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context::{AssistantContext, ContextId};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use assistant_tool::ToolUseStatus;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer};
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
use gpui::{
AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem,
DefiniteLength, EdgesRefinement, Empty, Entity, EventEmitter, Focusable, Hsla, ListAlignment,
ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription,
Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState,
MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task,
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle,
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, Role, StopReason,
};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::ProjectItem as _;
@@ -684,9 +681,6 @@ fn open_markdown_link(
struct EditMessageState {
editor: Entity<Editor>,
last_estimated_token_count: Option<usize>,
_subscription: Subscription,
_update_token_count_task: Option<Task<anyhow::Result<()>>>,
}
impl ActiveThread {
@@ -786,13 +780,6 @@ impl ActiveThread {
self.last_error.take();
}
/// Returns the editing message id and the estimated token count in the content
pub fn editing_message_id(&self) -> Option<(MessageId, usize)> {
self.editing_message
.as_ref()
.map(|(id, state)| (*id, state.last_estimated_token_count.unwrap_or(0)))
}
fn push_message(
&mut self,
id: &MessageId,
@@ -956,8 +943,8 @@ impl ActiveThread {
&tool_use.input,
self.thread
.read(cx)
.output_for_tool(&tool_use.id)
.map(|output| output.clone().into())
.tool_result(&tool_use.id)
.map(|result| result.content.clone().into())
.unwrap_or("".into()),
cx,
);
@@ -1138,91 +1125,15 @@ impl ActiveThread {
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
editor
});
let subscription = cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.update_editing_message_token_count(true, cx);
}
_ => {}
});
self.editing_message = Some((
message_id,
EditMessageState {
editor: editor.clone(),
last_estimated_token_count: None,
_subscription: subscription,
_update_token_count_task: None,
},
));
self.update_editing_message_token_count(false, cx);
cx.notify();
}
fn update_editing_message_token_count(&mut self, debounce: bool, cx: &mut Context<Self>) {
let Some((message_id, state)) = self.editing_message.as_mut() else {
return;
};
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
state.last_estimated_token_count.take();
return;
};
let editor = state.editor.clone();
let thread = self.thread.clone();
let message_id = *message_id;
state._update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
cx.background_executor()
.timer(Duration::from_millis(200))
.await;
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = thread.read(cx).context_for_message(message_id);
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
return None;
}
let request = language_model::LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
tools: vec![],
stop: vec![],
temperature: None,
};
Some(default_model.model.count_tokens(request, cx))
})? {
task.await?
} else {
0
};
this.update(cx, |this, cx| {
let Some((_message_id, state)) = this.editing_message.as_mut() else {
return;
};
state.last_estimated_token_count = Some(token_count);
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
})
}));
}
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
@@ -1764,9 +1675,6 @@ impl ActiveThread {
"confirm-edit-message",
"Regenerate",
)
.disabled(
edit_message_editor.read(cx).is_empty(cx),
)
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
@@ -1829,16 +1737,8 @@ impl ActiveThread {
),
};
let after_editing_message = self
.editing_message
.as_ref()
.map_or(false, |(editing_message_id, _)| {
message_id > *editing_message_id
});
v_flex()
.w_full()
.when(after_editing_message, |parent| parent.opacity(0.2))
.when_some(checkpoint, |parent, checkpoint| {
let mut is_pending = false;
let mut error = None;
@@ -2379,15 +2279,12 @@ impl ActiveThread {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, cx);
}
let is_open = self
.expanded_tool_uses
.get(&tool_use.id)
.copied()
.unwrap_or_default();
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
let fs = self
@@ -2446,9 +2343,6 @@ impl ActiveThread {
rendered.input.clone(),
tool_use_markdown_style(window, cx),
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
})
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
@@ -2475,16 +2369,12 @@ impl ActiveThread {
rendered.output.clone(),
tool_use_markdown_style(window, cx),
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
})
.on_url_click({
let workspace = self.workspace.clone();
move |text, window, cx| {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
}),
)),
),
@@ -2541,7 +2431,6 @@ impl ActiveThread {
open_markdown_link(text, workspace.clone(), window, cx);
}
})
.into_any_element()
})),
),
),
@@ -2655,7 +2544,7 @@ impl ActiveThread {
)
} else {
v_flex()
.my_2()
.my_3()
.rounded_lg()
.border_1()
.border_color(self.tool_card_border_color(cx))
@@ -2872,7 +2761,7 @@ impl ActiveThread {
)
})
}
}).into_any_element()
})
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
@@ -3064,12 +2953,6 @@ impl ActiveThread {
}
}
pub enum ActiveThreadEvent {
EditingMessageTokenCountChanged,
}
impl EventEmitter<ActiveThreadEvent> for ActiveThread {}
impl Render for ActiveThread {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()

View File

@@ -1,7 +1,7 @@
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet};
use collections::HashSet;
use editor::{
Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
actions::{GoToHunk, GoToPreviousHunk},
@@ -355,24 +355,16 @@ impl AgentDiff {
self.update_selection(&diff_hunks_in_ranges, window, cx);
}
let mut ranges_by_buffer = HashMap::default();
for hunk in &diff_hunks_in_ranges {
let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
if let Some(buffer) = buffer {
ranges_by_buffer
.entry(buffer.clone())
.or_insert_with(Vec::new)
.push(hunk.buffer_range.clone());
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
})
.detach_and_log_err(cx);
}
}
for (buffer, ranges) in ranges_by_buffer {
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_ranges(buffer, ranges, cx)
})
.detach_and_log_err(cx);
}
}
fn update_selection(

View File

@@ -9,14 +9,11 @@ use assistant_tool::{ToolSource, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use fs::Fs;
use gpui::{
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
};
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use settings::{Settings, update_settings_file};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
Switch, Tooltip, prelude::*,
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
};
use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
@@ -34,8 +31,6 @@ pub struct AssistantConfiguration {
expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
scrollbar_state: ScrollbarState,
}
impl AssistantConfiguration {
@@ -65,9 +60,6 @@ impl AssistantConfiguration {
},
);
let scroll_handle = ScrollHandle::new();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
let mut this = Self {
fs,
focus_handle,
@@ -76,8 +68,6 @@ impl AssistantConfiguration {
expanded_context_server_tools: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
scroll_handle,
scrollbar_state,
};
this.build_provider_configuration_views(window, cx);
this
@@ -119,7 +109,7 @@ pub enum AssistantConfigurationEvent {
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
impl AssistantConfiguration {
fn render_provider_configuration_block(
fn render_provider_configuration(
&mut self,
provider: &Arc<dyn LanguageModelProvider>,
cx: &mut Context<Self>,
@@ -174,7 +164,7 @@ impl AssistantConfiguration {
.p(DynamicSpacing::Base08.rems(cx))
.bg(cx.theme().colors().editor_background)
.border_1()
.border_color(cx.theme().colors().border)
.border_color(cx.theme().colors().border_variant)
.rounded_sm()
.map(|parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
@@ -185,33 +175,6 @@ impl AssistantConfiguration {
)
}
fn render_provider_configuration_section(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_4()
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(Headline::new("LLM Providers").size(HeadlineSize::Small))
.child(
Label::new("Add at least one provider to use AI-powered features.")
.color(Color::Muted),
),
)
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
)
}
fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
@@ -219,7 +182,6 @@ impl AssistantConfiguration {
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.flex_1()
.child(Headline::new("General Settings").size(HeadlineSize::Small))
@@ -271,7 +233,6 @@ impl AssistantConfiguration {
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.flex_1()
.child(
@@ -465,51 +426,39 @@ impl AssistantConfiguration {
impl Render for AssistantConfiguration {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.id("assistant-configuration")
.key_context("AgentConfiguration")
.track_focus(&self.focus_handle(cx))
.relative()
.size_full()
.pb_8()
.bg(cx.theme().colors().panel_background)
.size_full()
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(
v_flex()
.id("assistant-configuration-content")
.track_scroll(&self.scroll_handle)
.size_full()
.overflow_y_scroll()
.child(self.render_command_permission(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_provider_configuration_section(cx)),
)
.child(
div()
.id("assistant-configuration-scrollbar")
.occlude()
.absolute()
.right(px(3.))
.top_0()
.bottom_0()
.pb_6()
.w(px(12.))
.cursor_default()
.on_mouse_move(cx.listener(|_, _, _window, cx| {
cx.notify();
cx.stop_propagation()
}))
.on_hover(|_, _window, cx| {
cx.stop_propagation();
})
.on_any_mouse_down(|_, _window, cx| {
cx.stop_propagation();
})
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
cx.notify();
}))
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
.p(DynamicSpacing::Base16.rems(cx))
.mt_1()
.gap_6()
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(Headline::new("LLM Providers").size(HeadlineSize::Small))
.child(
Label::new("Add at least one provider to use AI-powered features.")
.color(Color::Muted),
),
)
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration(&provider, cx)),
),
)
}
}

View File

@@ -5,7 +5,7 @@ use std::time::Duration;
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider,
humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens,
make_lsp_adapter_delegate, render_remaining_tokens,
};
use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
@@ -25,7 +25,6 @@ use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::PromptBuilder;
use proto::Plan;
use settings::{Settings, update_settings_file};
use time::UtcOffset;
use ui::{
@@ -37,10 +36,10 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use zed_actions::agent::OpenConfiguration;
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::message_editor::{MessageEditor, MessageEditorEvent};
use crate::message_editor::MessageEditor;
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
@@ -181,8 +180,8 @@ pub struct AssistantPanel {
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
thread: Entity<ActiveThread>,
_thread_subscription: Subscription,
message_editor: Entity<MessageEditor>,
_active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
context_editor: Option<Entity<ContextEditor>>,
configuration: Option<Entity<AssistantConfiguration>>,
@@ -264,13 +263,6 @@ impl AssistantPanel {
)
});
let message_editor_subscription =
cx.subscribe(&message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
let history_store =
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
@@ -295,12 +287,6 @@ impl AssistantPanel {
)
});
let active_thread_subscription = cx.subscribe(&thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
Self {
active_view,
workspace,
@@ -309,12 +295,8 @@ impl AssistantPanel {
language_registry,
thread_store: thread_store.clone(),
thread,
_thread_subscription: thread_subscription,
message_editor,
_active_thread_subscriptions: vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
],
context_store,
context_editor: None,
configuration: None,
@@ -399,13 +381,6 @@ impl AssistantPanel {
.detach_and_log_err(cx);
}
let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
self.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
@@ -418,12 +393,12 @@ impl AssistantPanel {
)
});
let active_thread_subscription =
cx.subscribe(&self.thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
self._thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
self.message_editor = cx.new(|cx| {
MessageEditor::new(
@@ -437,19 +412,6 @@ impl AssistantPanel {
)
});
self.message_editor.focus_handle(cx).focus(window);
let message_editor_subscription =
cx.subscribe(&self.message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
self._active_thread_subscriptions = vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
];
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -575,13 +537,6 @@ impl AssistantPanel {
Some(this.thread_store.downgrade()),
)
});
let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
this.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
@@ -593,14 +548,6 @@ impl AssistantPanel {
cx,
)
});
let active_thread_subscription =
cx.subscribe(&this.thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
this.message_editor = cx.new(|cx| {
MessageEditor::new(
this.fs.clone(),
@@ -613,19 +560,6 @@ impl AssistantPanel {
)
});
this.message_editor.focus_handle(cx).focus(window);
let message_editor_subscription =
cx.subscribe(&this.message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
this._active_thread_subscriptions = vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
];
})
})
}
@@ -918,7 +852,7 @@ impl Panel for AssistantPanel {
}
impl AssistantPanel {
fn render_title_view(&self, _window: &mut Window, cx: &Context<Self>) -> AnyElement {
fn render_title_view(&self, _window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…";
let content = match &self.active_view {
@@ -978,8 +912,13 @@ impl AssistantPanel {
fn render_toolbar(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let active_thread = self.thread.read(cx);
let thread = active_thread.thread().read(cx);
let token_usage = thread.total_token_usage(cx);
let thread_id = thread.id().clone();
let is_generating = thread.is_generating();
let is_empty = active_thread.is_empty();
let focus_handle = self.focus_handle(cx);
let is_history = matches!(self.active_view, ActiveView::History);
let show_token_count = match &self.active_view {
@@ -988,8 +927,6 @@ impl AssistantPanel {
_ => false,
};
let focus_handle = self.focus_handle(cx);
let go_back_button = match &self.active_view {
ActiveView::History | ActiveView::Configuration => Some(
div().pl_1().child(
@@ -1036,9 +973,69 @@ impl AssistantPanel {
h_flex()
.h_full()
.gap_2()
.when(show_token_count, |parent|
parent.children(self.render_token_count(&thread, cx))
)
.when(show_token_count, |parent| match self.active_view {
ActiveView::Thread { .. } => {
if token_usage.total == 0 {
return parent;
}
let token_color = match token_usage.ratio {
TokenUsageRatio::Normal => Color::Muted,
TokenUsageRatio::Warning => Color::Warning,
TokenUsageRatio::Exceeded => Color::Error,
};
parent.child(
h_flex()
.flex_shrink_0()
.gap_0p5()
.child(
Label::new(assistant_context_editor::humanize_token_count(
token_usage.total,
))
.size(LabelSize::Small)
.color(token_color)
.map(|label| {
if is_generating {
label
.with_animation(
"used-tokens-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(
0.6, 1.,
)),
|label, delta| label.alpha(delta),
)
.into_any()
} else {
label.into_any_element()
}
}),
)
.child(
Label::new("/").size(LabelSize::Small).color(Color::Muted),
)
.child(
Label::new(assistant_context_editor::humanize_token_count(
token_usage.max,
))
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}
ActiveView::PromptEditor => {
let Some(editor) = self.context_editor.as_ref() else {
return parent;
};
let Some(element) = render_remaining_tokens(editor, cx) else {
return parent;
};
parent.child(element)
}
_ => parent,
})
.child(
h_flex()
.h_full()
@@ -1115,16 +1112,16 @@ impl AssistantPanel {
"New Text Thread",
NewTextThread.boxed_clone(),
)
.action("Prompt Library", Box::new(OpenPromptLibrary))
.action("Settings", Box::new(OpenConfiguration))
.action("Settings", OpenConfiguration.boxed_clone())
.separator()
.action(
"Install MCPs",
Box::new(zed_actions::Extensions {
zed_actions::Extensions {
category_filter: Some(
zed_actions::ExtensionCategoryFilter::ContextServers,
),
}),
}
.boxed_clone(),
)
},
))
@@ -1134,111 +1131,6 @@ impl AssistantPanel {
)
}
fn render_token_count(&self, thread: &Thread, cx: &App) -> Option<AnyElement> {
let is_generating = thread.is_generating();
let message_editor = self.message_editor.read(cx);
let conversation_token_usage = thread.total_token_usage(cx);
let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
self.thread.read(cx).editing_message_id()
{
let combined = thread
.token_usage_up_to_message(editing_message_id, cx)
.add(unsent_tokens);
(combined, unsent_tokens > 0)
} else {
let unsent_tokens = message_editor.last_estimated_token_count().unwrap_or(0);
let combined = conversation_token_usage.add(unsent_tokens);
(combined, unsent_tokens > 0)
};
let is_waiting_to_update_token_count = message_editor.is_waiting_to_update_token_count();
match self.active_view {
ActiveView::Thread { .. } => {
if total_token_usage.total == 0 {
return None;
}
let token_color = match total_token_usage.ratio() {
TokenUsageRatio::Normal if is_estimating => Color::Default,
TokenUsageRatio::Normal => Color::Muted,
TokenUsageRatio::Warning => Color::Warning,
TokenUsageRatio::Exceeded => Color::Error,
};
let token_count = h_flex()
.id("token-count")
.flex_shrink_0()
.gap_0p5()
.when(!is_generating && is_estimating, |parent| {
parent
.child(
h_flex()
.mr_0p5()
.size_2()
.justify_center()
.rounded_full()
.bg(cx.theme().colors().text.opacity(0.1))
.child(
div().size_1().rounded_full().bg(cx.theme().colors().text),
),
)
.tooltip(move |window, cx| {
Tooltip::with_meta(
"Estimated New Token Count",
None,
format!(
"Current Conversation Tokens: {}",
humanize_token_count(conversation_token_usage.total)
),
window,
cx,
)
})
})
.child(
Label::new(humanize_token_count(total_token_usage.total))
.size(LabelSize::Small)
.color(token_color)
.map(|label| {
if is_generating || is_waiting_to_update_token_count {
label
.with_animation(
"used-tokens-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any()
} else {
label.into_any_element()
}
}),
)
.child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
.child(
Label::new(humanize_token_count(total_token_usage.max))
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any();
Some(token_count)
}
ActiveView::PromptEditor => {
let editor = self.context_editor.as_ref()?;
let element = render_remaining_tokens(editor, cx)?;
Some(element.into_any_element())
}
_ => None,
}
}
fn render_active_thread_or_empty_state(
&self,
window: &mut Window,
@@ -1557,9 +1449,6 @@ impl AssistantPanel {
ThreadError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx)
}
ThreadError::ModelRequestLimitReached { plan } => {
self.render_model_request_limit_reached_error(plan, cx)
}
ThreadError::Message { header, message } => {
self.render_error_message(header, message, cx)
}
@@ -1662,67 +1551,6 @@ impl AssistantPanel {
.into_any()
}
fn render_model_request_limit_reached_error(
&self,
plan: Plan,
cx: &mut Context<Self>,
) -> AnyElement {
let error_message = match plan {
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
Plan::ZedPro => {
"Model request limit reached. Upgrade to usage-based billing for more requests."
}
};
let call_to_action = match plan {
Plan::Free => "Upgrade to Zed Pro",
Plan::ZedPro => "Upgrade to usage-based billing",
};
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Model Request Limit Reached").weight(FontWeight::MEDIUM)),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(error_message)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(
Button::new("subscribe", call_to_action).on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
},
)),
)
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
},
))),
)
.into_any()
}
fn render_error_message(
&self,
header: SharedString,
@@ -1895,27 +1723,10 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
fn quote_selection(
&self,
workspace: &mut Workspace,
creases: Vec<(String, String)>,
window: &mut Window,
cx: &mut Context<Workspace>,
_workspace: &mut Workspace,
_creases: Vec<(String, String)>,
_window: &mut Window,
_cx: &mut Context<Workspace>,
) {
let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
return;
};
if !panel.focus_handle(cx).contains_focused(window, cx) {
workspace.toggle_panel_focus::<AssistantPanel>(window, cx);
}
panel.update(cx, |_, cx| {
// Wait to create a new context until the workspace is no longer
// being updated.
cx.defer_in(window, move |panel, window, cx| {
if let Some(context) = panel.active_context_editor() {
context.update(cx, |context, cx| context.quote_creases(creases, window, cx));
};
});
});
}
}

View File

@@ -8,7 +8,6 @@ use std::sync::atomic::AtomicBool;
use anyhow::Result;
use editor::{CompletionProvider, Editor, ExcerptId};
use file_icons::FileIcons;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{App, Entity, Task, WeakEntity};
use http_client::HttpClientWithUrl;
use language::{Buffer, CodeLabel, HighlightId};
@@ -38,24 +37,7 @@ pub(crate) enum Match {
File(FileMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Mode(ModeMatch),
}
pub struct ModeMatch {
mat: Option<StringMatch>,
mode: ContextPickerMode,
}
impl Match {
pub fn score(&self) -> f64 {
match self {
Match::File(file) => file.mat.score,
Match::Mode(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
Match::Thread(_) => 1.,
Match::Symbol(_) => 1.,
Match::Fetch(_) => 1.,
}
}
Mode(ContextPickerMode),
}
fn search(
@@ -144,54 +126,19 @@ fn search(
matches.extend(
supported_context_picker_modes(&thread_store)
.into_iter()
.map(|mode| Match::Mode(ModeMatch { mode, mat: None })),
.map(Match::Mode),
);
Task::ready(matches)
} else {
let executor = cx.background_executor().clone();
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let modes = supported_context_picker_modes(&thread_store);
let mode_candidates = modes
.iter()
.enumerate()
.map(|(ix, mode)| StringMatchCandidate::new(ix, mode.mention_prefix()))
.collect::<Vec<_>>();
cx.background_spawn(async move {
let mut matches = search_files_task
search_files_task
.await
.into_iter()
.map(Match::File)
.collect::<Vec<_>>();
let mode_matches = fuzzy::match_strings(
&mode_candidates,
&query,
false,
100,
&Arc::new(AtomicBool::default()),
executor,
)
.await;
matches.extend(mode_matches.into_iter().map(|mat| {
Match::Mode(ModeMatch {
mode: modes[mat.candidate_id],
mat: Some(mat),
})
}));
matches.sort_by(|a, b| {
b.score()
.partial_cmp(&a.score())
.unwrap_or(std::cmp::Ordering::Equal)
});
matches
.collect()
})
}
}
@@ -601,7 +548,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
Match::Mode(ModeMatch { mode, .. }) => {
Match::Mode(mode) => {
Some(Self::completion_for_mode(source_range.clone(), mode))
}
})

View File

@@ -2,23 +2,22 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use crate::context::format_context_as_string;
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode,
EditorStyle, MultiBuffer,
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode, EditorStyle,
MultiBuffer,
};
use file_icons::FileIcons;
use fs::Fs;
use gpui::{
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
Animation, AnimationExt, App, Entity, Focusable, Subscription, TextStyle, WeakEntity,
linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage};
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@@ -56,8 +55,6 @@ pub struct MessageEditor {
edits_expanded: bool,
editor_is_expanded: bool,
waiting_for_summaries_to_send: bool,
last_estimated_token_count: Option<usize>,
update_token_count_task: Option<Task<anyhow::Result<()>>>,
_subscriptions: Vec<Subscription>,
}
@@ -132,18 +129,8 @@ impl MessageEditor {
let incompatible_tools =
cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx));
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.message_or_context_changed(true, cx);
}
_ => {}
}),
cx.observe(&context_store, |this, _, cx| {
this.message_or_context_changed(false, cx);
}),
];
let subscriptions =
vec![cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event)];
Self {
editor: editor.clone(),
@@ -169,8 +156,6 @@ impl MessageEditor {
waiting_for_summaries_to_send: false,
profile_selector: cx
.new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
last_estimated_token_count: None,
update_token_count_task: None,
_subscriptions: subscriptions,
}
}
@@ -271,9 +256,6 @@ impl MessageEditor {
text
});
self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount);
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
@@ -955,80 +937,6 @@ impl MessageEditor {
.label_size(LabelSize::Small),
)
}
pub fn last_estimated_token_count(&self) -> Option<usize> {
self.last_estimated_token_count
}
pub fn is_waiting_to_update_token_count(&self) -> bool {
self.update_token_count_task.is_some()
}
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
self.last_estimated_token_count.take();
return;
};
let context_store = self.context_store.clone();
let editor = self.editor.clone();
let thread = self.thread.clone();
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
cx.background_executor()
.timer(Duration::from_millis(200))
.await;
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = context_store.read(cx).context().iter();
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
return None;
}
let request = language_model::LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
tools: vec![],
stop: vec![],
temperature: None,
};
Some(default_model.model.count_tokens(request, cx))
})? {
task.await?
} else {
0
};
this.update(cx, |this, cx| {
this.last_estimated_token_count = Some(token_count);
cx.emit(MessageEditorEvent::EstimatedTokenCount);
this.update_token_count_task.take();
})
}));
}
}
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
pub enum MessageEditorEvent {
EstimatedTokenCount,
Changed,
}
impl Focusable for MessageEditor {
@@ -1041,7 +949,6 @@ impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx);
let total_token_usage = thread.total_token_usage(cx);
let token_usage_ratio = total_token_usage.ratio();
let action_log = self.thread.read(cx).action_log();
let changed_buffers = action_log.read(cx).changed_buffers(cx);
@@ -1090,8 +997,15 @@ impl Render for MessageEditor {
parent.child(self.render_changed_buffers(&changed_buffers, window, cx))
})
.child(self.render_editor(font_size, line_height, window, cx))
.when(token_usage_ratio != TokenUsageRatio::Normal, |parent| {
parent.child(self.render_token_limit_callout(line_height, token_usage_ratio, cx))
})
.when(
total_token_usage.ratio != TokenUsageRatio::Normal,
|parent| {
parent.child(self.render_token_limit_callout(
line_height,
total_token_usage.ratio,
cx,
))
},
)
}
}

View File

@@ -6,7 +6,7 @@ use std::time::Instant;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use feature_flags::{self, FeatureFlagAppExt};
@@ -18,13 +18,12 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use prompt_store::PromptBuilder;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
@@ -227,33 +226,7 @@ pub enum DetailedSummaryState {
pub struct TotalTokenUsage {
pub total: usize,
pub max: usize,
}
impl TotalTokenUsage {
pub fn ratio(&self) -> TokenUsageRatio {
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
.parse()
.unwrap();
#[cfg(not(debug_assertions))]
let warning_threshold: f32 = 0.8;
if self.total >= self.max {
TokenUsageRatio::Exceeded
} else if self.total as f32 / self.max as f32 >= warning_threshold {
TokenUsageRatio::Warning
} else {
TokenUsageRatio::Normal
}
}
pub fn add(&self, tokens: usize) -> TotalTokenUsage {
TotalTokenUsage {
total: self.total + tokens,
max: self.max,
}
}
pub ratio: TokenUsageRatio,
}
#[derive(Debug, Default, PartialEq, Eq)]
@@ -287,12 +260,14 @@ pub struct Thread {
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
request_token_usage: Vec<TokenUsage>,
cumulative_token_usage: TokenUsage,
exceeded_window_error: Option<ExceededWindowError>,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -338,12 +313,12 @@ impl Thread {
.spawn(async move { Some(project_snapshot.await) })
.shared()
},
request_token_usage: Vec::new(),
cumulative_token_usage: TokenUsage::default(),
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
}
}
@@ -406,15 +381,23 @@ impl Thread {
tool_use,
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
request_token_usage: serialized.request_token_usage,
cumulative_token_usage: serialized.cumulative_token_usage,
exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
}
}
pub fn set_request_callback(
&mut self,
callback: impl 'static
+ FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>]),
) {
self.request_callback = Some(Box::new(callback));
}
pub fn id(&self) -> &ThreadId {
&self.id
}
@@ -660,30 +643,10 @@ impl Thread {
self.tool_use.tool_result(id)
}
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
Some(&self.tool_use.tool_result(id)?.content)
}
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
self.tool_use.tool_result_card(id).cloned()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id)
}
/// Filter out contexts that have already been included in previous messages
pub fn filter_new_context<'a>(
&self,
context: impl Iterator<Item = &'a AssistantContext>,
) -> impl Iterator<Item = &'a AssistantContext> {
context.filter(|ctx| self.is_context_new(ctx))
}
fn is_context_new(&self, context: &AssistantContext) -> bool {
!self.context.contains_key(&context.id())
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
@@ -695,9 +658,10 @@ impl Thread {
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
// Filter out contexts that have already been included in previous messages
let new_context: Vec<_> = context
.into_iter()
.filter(|ctx| self.is_context_new(ctx))
.filter(|ctx| !self.context.contains_key(&ctx.id()))
.collect();
if !new_context.is_empty() {
@@ -877,7 +841,6 @@ impl Thread {
.collect(),
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage,
request_token_usage: this.request_token_usage.clone(),
detailed_summary_state: this.detailed_summary_state.clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
})
@@ -1063,16 +1026,28 @@ impl Thread {
cx: &mut Context<Self>,
) {
let pending_completion_id = post_inc(&mut self.completion_count);
let request_callback_parameters = if self.request_callback.is_some() {
Some((request.clone(), Vec::new()))
} else {
None
};
let task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let mut request_callback_parameters = request_callback_parameters;
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
while let Some(event) = events.next().await {
if let Some((_, response_events)) = request_callback_parameters.as_mut() {
response_events
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
}
let event = event?;
thread.update(cx, |thread, cx| {
@@ -1088,7 +1063,6 @@ impl Thread {
stop_reason = reason;
}
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
thread.update_token_usage_at_last_message(token_usage);
thread.cumulative_token_usage = thread.cumulative_token_usage
+ token_usage
- current_token_usage;
@@ -1176,7 +1150,7 @@ impl Thread {
}
})?;
anyhow::Ok(stop_reason)
anyhow::Ok((stop_reason, request_callback_parameters))
};
let result = stream_completion.await;
@@ -1185,14 +1159,24 @@ impl Thread {
.update(cx, |thread, cx| {
thread.finalize_pending_checkpoint(cx);
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
Ok((stop_reason, request_callback_parameters)) => {
match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
},
if let Some((request_callback, (request, response_events))) = thread
.request_callback
.as_mut()
.zip(request_callback_parameters.as_ref())
{
request_callback(request, response_events);
}
}
Err(error) => {
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
@@ -1200,12 +1184,6 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else if let Some(error) =
error.downcast_ref::<ModelRequestLimitReachedError>()
{
cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan },
));
} else if let Some(known_error) =
error.downcast_ref::<LanguageModelKnownError>()
{
@@ -1235,7 +1213,9 @@ impl Thread {
thread.cancel_last_completion(cx);
}
}
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
cx.emit(ThreadEvent::Stopped(
result.map(|result| result.0).map_err(Arc::new),
));
thread.auto_capture_telemetry(cx);
@@ -1475,12 +1455,6 @@ impl Thread {
)
};
// Store the card separately if it exists
if let Some(card) = tool_result.card.clone() {
self.tool_use
.insert_tool_result_card(tool_use_id.clone(), card);
}
cx.spawn({
async move |thread: WeakEntity<Thread>, cx| {
let output = tool_result.output.await;
@@ -1863,14 +1837,14 @@ impl Thread {
.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
}
pub fn reject_edits_in_ranges(
pub fn reject_edits_in_range(
&mut self,
buffer: Entity<language::Buffer>,
buffer_ranges: Vec<Range<language::Anchor>>,
buffer_range: Range<language::Anchor>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.action_log.update(cx, |action_log, cx| {
action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
action_log.reject_edits_in_range(buffer, buffer_range, cx)
})
}
@@ -1930,35 +1904,6 @@ impl Thread {
self.cumulative_token_usage
}
pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return TotalTokenUsage::default();
};
let max = model.model.max_token_count();
let index = self
.messages
.iter()
.position(|msg| msg.id == message_id)
.unwrap_or(0);
if index == 0 {
return TotalTokenUsage { total: 0, max };
}
let token_usage = &self
.request_token_usage
.get(index - 1)
.cloned()
.unwrap_or_default();
TotalTokenUsage {
total: token_usage.total_tokens() as usize,
max,
}
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {
@@ -1972,33 +1917,30 @@ impl Thread {
return TotalTokenUsage {
total: exceeded_error.token_count,
max,
ratio: TokenUsageRatio::Exceeded,
};
}
}
let total = self
.token_usage_at_last_message()
.unwrap_or_default()
.total_tokens() as usize;
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
.parse()
.unwrap();
#[cfg(not(debug_assertions))]
let warning_threshold: f32 = 0.8;
TotalTokenUsage { total, max }
}
let total = self.cumulative_token_usage.total_tokens() as usize;
fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
self.request_token_usage
.get(self.messages.len().saturating_sub(1))
.or_else(|| self.request_token_usage.last())
.cloned()
}
let ratio = if total >= max {
TokenUsageRatio::Exceeded
} else if total as f32 / max as f32 >= warning_threshold {
TokenUsageRatio::Warning
} else {
TokenUsageRatio::Normal
};
fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
let placeholder = self.token_usage_at_last_message().unwrap_or_default();
self.request_token_usage
.resize(self.messages.len(), placeholder);
if let Some(last) = self.request_token_usage.last_mut() {
*last = token_usage;
}
TotalTokenUsage { total, max, ratio }
}
pub fn deny_tool_use(
@@ -2023,8 +1965,6 @@ pub enum ThreadError {
PaymentRequired,
#[error("Max monthly spend reached")]
MaxMonthlySpendReached,
#[error("Model request limit reached")]
ModelRequestLimitReached { plan: Plan },
#[error("Message {header}: {message}")]
Message {
header: SharedString,

View File

@@ -509,8 +509,6 @@ pub struct SerializedThread {
#[serde(default)]
pub cumulative_token_usage: TokenUsage,
#[serde(default)]
pub request_token_usage: Vec<TokenUsage>,
#[serde(default)]
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
@@ -599,7 +597,6 @@ impl LegacySerializedThread {
messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
initial_project_snapshot: self.initial_project_snapshot,
cumulative_token_usage: TokenUsage::default(),
request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
}

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
@@ -27,7 +27,26 @@ pub struct ToolUse {
pub needs_confirmation: bool,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
@@ -35,9 +54,10 @@ pub struct ToolUseState {
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
@@ -46,7 +66,6 @@ impl ToolUseState {
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
}
}
@@ -238,18 +257,6 @@ impl ToolUseState {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,

View File

@@ -191,12 +191,15 @@ impl RenderOnce for ContextPill {
ContextPill::Suggested {
name,
icon_path: _,
kind: _,
kind,
focused,
on_click,
} => base_pill
.cursor_pointer()
.pr_1()
.when(*focused, |this| {
this.bg(color.element_background.opacity(0.5))
})
.border_dashed()
.border_color(if *focused {
color.border_focused
@@ -204,17 +207,30 @@ impl RenderOnce for ContextPill {
color.border
})
.hover(|style| style.bg(color.element_hover.opacity(0.5)))
.when(*focused, |this| {
this.bg(color.element_background.opacity(0.5))
})
.child(
div().max_w_64().child(
div().px_0p5().max_w_64().child(
Label::new(name.clone())
.size(LabelSize::Small)
.color(Color::Muted)
.truncate(),
),
)
.child(
Label::new(match kind {
ContextKind::File => "Active Tab",
ContextKind::Thread
| ContextKind::Directory
| ContextKind::FetchedUrl
| ContextKind::Symbol => "Active",
})
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.child(
Icon::new(IconName::Plus)
.size(IconSize::XSmall)
.into_any_element(),
)
.tooltip(|window, cx| {
Tooltip::with_meta("Suggested Context", None, "Click to add it", window, cx)
})

View File

@@ -3,7 +3,7 @@ use buffer_diff::BufferDiff;
use collections::BTreeMap;
use futures::{StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
use std::{cmp, ops::Range, sync::Arc};
use text::{Edit, Patch, Rope};
@@ -363,10 +363,10 @@ impl ActionLog {
}
}
pub fn reject_edits_in_ranges(
pub fn reject_edits_in_range(
&mut self,
buffer: Entity<Buffer>,
buffer_ranges: Vec<Range<impl language::ToPoint>>,
buffer_range: Range<impl language::ToPoint>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
@@ -403,15 +403,29 @@ impl ActionLog {
}
TrackedBufferStatus::Modified => {
buffer.update(cx, |buffer, cx| {
let mut buffer_row_ranges = buffer_ranges
.into_iter()
.map(|range| {
range.start.to_point(buffer).row..range.end.to_point(buffer).row
})
.peekable();
let buffer_range =
buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
let mut edits_to_revert = Vec::new();
for edit in tracked_buffer.unreviewed_changes.edits() {
if buffer_range.end.row < edit.new.start {
break;
} else if buffer_range.start.row > edit.new.end {
continue;
}
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
let new_range = tracked_buffer
.snapshot
.anchor_before(Point::new(edit.new.start, 0))
@@ -419,35 +433,7 @@ impl ActionLog {
Point::new(edit.new.end, 0),
tracked_buffer.snapshot.max_point(),
));
let new_row_range = new_range.start.to_point(buffer).row
..new_range.end.to_point(buffer).row;
let mut revert = false;
while let Some(buffer_row_range) = buffer_row_ranges.peek() {
if buffer_row_range.end < new_row_range.start {
buffer_row_ranges.next();
} else if buffer_row_range.start > new_row_range.end {
break;
} else {
revert = true;
break;
}
}
if revert {
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
edits_to_revert.push((new_range, old_text));
}
edits_to_revert.push((new_range, old_text));
}
buffer.edit(edits_to_revert, None, cx);
@@ -613,7 +599,6 @@ fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edi
}
}
#[derive(Copy, Clone, Debug)]
enum ChangeAuthor {
User,
Agent,
@@ -1150,48 +1135,9 @@ mod tests {
)]
);
// If the rejected range doesn't overlap with any hunk, we ignore it.
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(4, 0)..Point::new(4, 0)],
cx,
)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(1, 0)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
})
.await
.unwrap();
@@ -1214,11 +1160,7 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(4, 0)..Point::new(4, 0)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
})
.await
.unwrap();
@@ -1230,82 +1172,6 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_multiple_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
.unwrap()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log.update(cx, |log, cx| {
let range_1 = buffer.read(cx).anchor_before(Point::new(0, 0))
..buffer.read(cx).anchor_before(Point::new(1, 0));
let range_2 = buffer.read(cx).anchor_before(Point::new(5, 0))
..buffer.read(cx).anchor_before(Point::new(5, 3));
log.reject_edits_in_ranges(buffer.clone(), vec![range_1, range_2], cx)
.detach();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_deleted_file(cx: &mut TestAppContext) {
init_test(cx);
@@ -1349,11 +1215,7 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(0, 0)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
})
.await
.unwrap();
@@ -1404,11 +1266,7 @@ mod tests {
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(0, 11)],
cx,
)
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
})
.await
.unwrap();
@@ -1454,7 +1312,7 @@ mod tests {
.update(cx, |log, cx| {
let range = buffer.read(cx).random_byte_range(0, &mut rng);
log::info!("rejecting edits in range {:?}", range);
log.reject_edits_in_ranges(buffer.clone(), vec![range], cx)
log.reject_edits_in_range(buffer.clone(), range, cx)
})
.await
.unwrap();

View File

@@ -9,10 +9,6 @@ use std::fmt::Formatter;
use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
@@ -28,87 +24,16 @@ pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
Error(SharedString),
}
impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
ToolUseStatus::Error(out) => out.clone(),
}
}
}
/// The result of running a tool, containing both the asynchronous output
/// and an optional card view that can be rendered immediately.
/// The result of running a tool
pub struct ToolResult {
/// The asynchronous task that will eventually resolve to the tool's output
pub output: Task<Result<String>>,
/// An optional view to present the output of the tool.
pub card: Option<AnyToolCard>,
}
pub trait ToolCard: 'static + Sized {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement;
}
#[derive(Clone)]
pub struct AnyToolCard {
entity: gpui::AnyEntity,
render: fn(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement,
}
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
fn from(entity: Entity<T>) -> Self {
fn downcast_render<T: ToolCard>(
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
cx: &mut App,
) -> AnyElement {
let entity = entity.downcast::<T>().unwrap();
entity.update(cx, |entity, cx| {
entity.render(status, window, cx).into_any_element()
})
}
Self {
entity: entity.into(),
render: downcast_render::<T>,
}
}
}
impl AnyToolCard {
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
(self.render)(self.entity.clone(), status, window, cx)
}
}
impl From<Task<Result<String>>> for ToolResult {
/// Convert from a task to a ToolResult with no card
/// Convert from a task to a ToolResult
fn from(output: Task<Result<String>>) -> Self {
Self { output, card: None }
Self { output }
}
}

View File

@@ -16,7 +16,6 @@ anyhow.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
collections.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@@ -33,9 +32,7 @@ ui.workspace = true
util.workspace = true
worktree.workspace = true
open = { workspace = true }
web_search.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
@@ -43,5 +40,6 @@ gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
settings = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -7,8 +7,8 @@ mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
mod fetch_tool;
mod find_replace_file_tool;
mod list_directory_tool;
mod move_path_tool;
mod now_tool;
@@ -22,17 +22,14 @@ mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
@@ -42,8 +39,8 @@ use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_replace_file_tool::FindReplaceFileTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
@@ -59,39 +56,28 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CodeActionTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(CopyPathTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(FetchTool::new(http_client));
registry.register_tool(FindReplaceFileTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(EditFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(ListDirectoryTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(TerminalTool);
registry.register_tool(ThinkingTool);
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
move |is_enabled, cx| {
if is_enabled {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
}
})
.detach();
registry.register_tool(FetchTool::new(http_client));
}
#[cfg(test)]

View File

@@ -0,0 +1,183 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use crate::replace::replace_exact;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
/// The full path of the file to modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
pub display_description: String,
/// The text to replace.
pub old_string: String,
/// The text to replace it with.
pub new_string: String,
}
pub struct EditFileTool;
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("edit_file_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<EditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Edit file".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.old_string.is_empty() {
return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
}
if input.old_string == input.new_string {
return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
}
let result = cx
.background_spawn(async move {
// Try to match exactly
let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
if diff.edits.is_empty() {
return None;
}
let old_text = snapshot.text();
Some((old_text, diff))
})
.await;
let Some((old_text, diff)) = result else {
let err = buffer.read_with(cx, |buffer, _cx| {
let file_exists = buffer
.file()
.map_or(false, |file| file.disk_state().exists());
if !file_exists {
anyhow!("{} does not exist", input.path.display())
} else if buffer.is_empty() {
anyhow!(
"{} is empty, so the provided `old_string` wasn't found.",
input.path.display()
)
} else {
anyhow!("Failed to match the provided `old_string`")
}
})?;
return Err(err)
};
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx)
});
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
buffer.finalize_last_transaction();
buffer.snapshot()
});
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
});
snapshot
})?;
project.update( cx, |project, cx| {
project.save_buffer(buffer, cx)
})?.await?;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
}).await;
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
}).into()
}
}

View File

@@ -0,0 +1,45 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location
To make a file edit, provide the following:
1. path: The full path to the file you wish to modify in the project. This path must include the root directory in the project.
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
3. new_string: The edited text, which will replace the old_string in the file.
The tool will replace ONE occurrence of old_string with new_string in the specified file.
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
- Include AT LEAST 3-5 lines of context BEFORE the change point
- Include AT LEAST 3-5 lines of context AFTER the change point
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
- Make separate calls to this tool for each instance
- Each call must uniquely identify its specific instance using extensive context
3. VERIFICATION: Before using this tool:
- Check how many instances of the target text exist in the file
- If multiple instances exist, gather enough context to uniquely identify each one
- Plan separate tool calls for each instance
WARNING: If you do not follow these requirements:
- The tool will fail if old_string matches multiple locations
- The tool will fail if old_string doesn't match exactly (including whitespace)
- You may change the wrong instance if you don't include enough context
When making edits:
- Ensure the edit results in idiomatic, correct code
- Do not leave the code in a broken state
- Always use fully-qualified project paths (starting with the name of one of the project's root directories)
If you want to create a new file, use the `create_file` tool instead of this tool. Don't pass an empty `old_string`.
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.

View File

@@ -1,268 +0,0 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use crate::replace::replace_exact;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindReplaceFileToolInput {
/// The path of the file to modify.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
pub display_description: String,
/// The unique string to find in the file. This string cannot be empty;
/// if the string is empty, the tool call will fail. Remember, do not use this tool
/// to create new files from scratch, or to overwrite existing files! Use a different
/// approach if you want to do that.
///
/// If this string appears more than once in the file, this tool call will fail,
/// so it is absolutely critical that you verify ahead of time that the string
/// is unique. You can search within the file to verify this.
///
/// To make the string more likely to be unique, include a minimum of 3 lines of context
/// before the string you actually want to find, as well as a minimum of 3 lines of
/// context after the string you want to find. (These lines of context should appear
/// in the `replace` string as well.) If 3 lines of context is not enough to obtain
/// a string that appears only once in the file, then double the number of context lines
/// until the string becomes unique. (Start with 3 lines before and 3 lines after
/// though, because too much context is needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. This string must be exactly as
/// it appears in the file, because this tool will do a literal find/replace, and if
/// even one character in this string is different in any way from how it appears
/// in the file, then the tool call will fail.
///
/// If you get an error that the `find` string was not found, this means that either
/// you made a mistake, or that the file has changed since you last looked at it.
/// Either way, when this happens, you should retry doing this tool call until it
/// succeeds, up to 3 times. Each time you retry, you should take another look at
/// the exact text of the file in question, to make sure that you are searching for
/// exactly the right string. Regardless of whether it was because you made a mistake
/// or because the file changed since you last looked at it, you should be extra
/// careful when retrying in this way. It's a bad experience for the user if
/// this `find` string isn't found, so be super careful to get it exactly right!
///
/// <example>
/// If a file contains this code:
///
/// ```ignore
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
/// // Check if user exists first
/// let user = database.find_user(user_id)?;
///
/// // This is the part we want to modify
/// if user.role == "admin" {
/// return Ok(true);
/// }
///
/// // Check other permissions
/// check_custom_permissions(user_id)
/// }
/// ```
///
/// Your find string should include at least 3 lines of context before and after the part
/// you want to change:
///
/// ```ignore
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
/// // Check if user exists first
/// let user = database.find_user(user_id)?;
///
/// // This is the part we want to modify
/// if user.role == "admin" {
/// return Ok(true);
/// }
///
/// // Check other permissions
/// check_custom_permissions(user_id)
/// }
/// ```
///
/// And your replace string might look like:
///
/// ```ignore
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
/// // Check if user exists first
/// let user = database.find_user(user_id)?;
///
/// // This is the part we want to modify
/// if user.role == "admin" || user.role == "superuser" {
/// return Ok(true);
/// }
///
/// // Check other permissions
/// check_custom_permissions(user_id)
/// }
/// ```
/// </example>
pub find: String,
/// The string to replace the one unique occurrence of the find string with.
pub replace: String,
}
pub struct FindReplaceFileTool;
impl Tool for FindReplaceFileTool {
fn name(&self) -> String {
"find_replace_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("find_replace_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindReplaceFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindReplaceFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Edit file".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.find.is_empty() {
return Err(anyhow!("`find` string cannot be empty. Use a different tool if you want to create a file."));
}
if input.find == input.replace {
return Err(anyhow!("The `find` and `replace` strings are identical, so no changes would be made."));
}
let result = cx
.background_spawn(async move {
// Try to match exactly
let diff = replace_exact(&input.find, &input.replace, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&input.find, &input.replace, &snapshot))?;
if diff.edits.is_empty() {
return None;
}
let old_text = snapshot.text();
Some((old_text, diff))
})
.await;
let Some((old_text, diff)) = result else {
let err = buffer.read_with(cx, |buffer, _cx| {
let file_exists = buffer
.file()
.map_or(false, |file| file.disk_state().exists());
if !file_exists {
anyhow!("{} does not exist", input.path.display())
} else if buffer.is_empty() {
anyhow!(
"{} is empty, so the provided `find` string wasn't found.",
input.path.display()
)
} else {
anyhow!("Failed to match the provided `find` string")
}
})?;
return Err(err)
};
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx)
});
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
buffer.finalize_last_transaction();
buffer.snapshot()
});
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
});
snapshot
})?;
project.update( cx, |project, cx| {
project.save_buffer(buffer, cx)
})?.await?;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
}).await;
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
}).into()
}
}

View File

@@ -12,7 +12,7 @@ use util::markdown::MarkdownString;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ListDirectoryToolInput {
/// The relative path of the directory to list.
/// The fully-qualified path of the directory to list in the project.
///
/// This path should never be absolute, and the first component
/// of the path should always be a root directory in a project.

View File

@@ -1 +1 @@
Lists files and directories in a given path.
Lists files and directories in a given path. Prefer the `regex_search` or `path_search` tools when searching the codebase.

View File

@@ -6,14 +6,14 @@ use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc};
use ui::IconName;
use util::paths::PathMatcher;
use worktree::Snapshot;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct PathSearchToolInput {
/// The glob to search all project paths for.
/// The glob to match against every path in the project.
///
/// <example>
/// If the project has the following root directories:
@@ -76,66 +76,114 @@ impl Tool for PathSearchTool {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { &glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))).into(),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
let offset = offset as usize;
let task = search_paths(&glob, project, cx);
cx.background_spawn(async move {
let mut matches = Vec::new();
for worktree in snapshots {
let root_name = worktree.root_name();
// Don't consider ignored entries.
for entry in worktree.entries(false, 0) {
if path_matcher.is_match(&entry.path) {
matches.push(
PathBuf::from(root_name)
.join(&entry.path)
.to_string_lossy()
.to_string(),
);
}
}
}
if matches.is_empty() {
Ok(format!("No paths in the project matched the glob {glob:?}"))
} else {
// Sort to group entries in the same directory together.
matches.sort();
let total_matches = matches.len();
let response = if total_matches > RESULTS_PER_PAGE + offset as usize {
let paginated_matches: Vec<_> = matches
.into_iter()
.skip(offset as usize)
.take(RESULTS_PER_PAGE)
.collect();
format!(
"Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}",
total_matches,
offset + 1,
offset as usize + paginated_matches.len(),
paginated_matches.join("\n")
)
} else {
matches.join("\n")
};
Ok(response)
let matches = task.await?;
let paginated_matches = &matches[cmp::min(offset, matches.len())..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
let mut message = format!(
"Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n",
matches.len(),
offset + 1,
offset as usize + paginated_matches.len(),
);
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}).into()
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))).into(),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
#[cfg(test)]
mod test {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_path_search_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
serde_json::json!({
"apple": {
"banana": {
"carrot": "1",
},
"bandana": {
"carbonara": "2",
},
"endive": "3"
}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let matches = cx
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
let matches = cx
.update(|cx| search_paths("**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,3 +1,7 @@
Returns paths in the project which match the given glob.
Fast file pattern matching tool that works with any codebase size
Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
- Returns matching file paths sorted alphabetically
- Prefer the `regex_search` tool to this tool when searching for symbols unless you have specific information about paths.
- Use this tool when you need to find files by name patterns
- Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.

View File

@@ -1,7 +1,6 @@
Searches the entire project for the given regular expression.
Returns a list of paths that matched the query. For each path, it returns some excerpts of the matched text.
Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages.
This tool is not aware of semantics and does not use any information from language servers, so it should only be used when no available semantic tool (e.g. one that uses language servers) could fit a particular use case instead.
- Prefer this tool when searching for files containing symbols in the project.
- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.)
- Use this tool when you need to find files containing specific patterns
- Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages.

View File

@@ -1,213 +0,0 @@
use std::{sync::Arc, time::Duration};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt, TryFutureExt};
use gpui::{
Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
pulsating_between,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use zed_llm_client::WebSearchResponse;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
/// The search term or question to query on the web.
query: String,
}
pub struct WebSearchTool;
impl Tool for WebSearchTool {
fn name(&self) -> String {
"web_search".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
"Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
}
fn icon(&self) -> IconName {
IconName::Globe
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<WebSearchToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Web Search".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
return Task::ready(Err(anyhow!("Web search is not available."))).into();
};
let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
let output = cx.background_spawn({
let search_task = search_task.clone();
async move {
let response = search_task.await.map_err(|err| anyhow!(err))?;
serde_json::to_string(&response).context("Failed to serialize search results")
}
});
ToolResult {
output,
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
}
}
}
struct WebSearchToolCard {
response: Option<Result<WebSearchResponse>>,
_task: Task<()>,
}
impl WebSearchToolCard {
fn new(
search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
cx: &mut Context<Self>,
) -> Self {
let _task = cx.spawn(async move |this, cx| {
let response = search_task.await.map_err(|err| anyhow!(err));
this.update(cx, |this, cx| {
this.response = Some(response);
cx.notify();
})
.ok();
});
Self {
response: None,
_task,
}
}
}
impl ToolCard for WebSearchToolCard {
fn render(
&mut self,
_status: &ToolUseStatus,
_window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
Icon::new(IconName::Globe)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(match self.response.as_ref() {
Some(Ok(response)) => {
let text: SharedString = if response.citations.len() == 1 {
"1 result".into()
} else {
format!("{} results", response.citations.len()).into()
};
h_flex()
.gap_1p5()
.child(Label::new("Searched the Web").size(LabelSize::Small))
.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
)
.child(Label::new(text).size(LabelSize::Small))
.into_any_element()
}
Some(Err(error)) => div()
.id("web-search-error")
.child(Label::new("Web Search failed").size(LabelSize::Small))
.tooltip(Tooltip::text(error.to_string()))
.into_any_element(),
None => Label::new("Searching the Web…")
.size(LabelSize::Small)
.with_animation(
"web-search-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element(),
})
.into_any();
let content =
self.response.as_ref().and_then(|response| match response {
Ok(response) => {
Some(
v_flex()
.ml_1p5()
.pl_1p5()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.gap_1()
.children(response.citations.iter().enumerate().map(
|(index, citation)| {
let title = citation.title.clone();
let url = citation.url.clone();
Button::new(("citation", index), title)
.label_size(LabelSize::Small)
.color(Color::Muted)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.truncate(true)
.tooltip({
let url = url.clone();
move |window, cx| {
Tooltip::with_meta(
"Citation Link",
None,
url.clone(),
window,
cx,
)
}
})
.on_click({
let url = url.clone();
move |_, _, cx| cx.open_url(&url)
})
},
))
.into_any(),
)
}
Err(_) => None,
});
v_flex().my_2().gap_1().child(header).children(content)
}
}

View File

@@ -1,10 +0,0 @@
create table subscription_usages (
id serial primary key,
user_id integer not null,
period_start_at timestamp without time zone not null,
period_end_at timestamp without time zone not null,
model_requests int not null default 0,
edit_predictions int not null default 0
);
create unique index uix_subscription_usages_on_user_id_start_at_end_at on subscription_usages (user_id, period_start_at, period_end_at);

View File

@@ -15,12 +15,10 @@ use stripe::{
BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
use util::ResultExt;
use crate::api::events::SnowflakeRow;
use crate::db::billing_subscription::{
@@ -54,7 +52,6 @@ pub fn router() -> Router {
post(manage_billing_subscription),
)
.route("/billing/monthly_spend", get(get_monthly_spend))
.route("/billing/usage", get(get_current_usage))
}
#[derive(Debug, Deserialize)]
@@ -162,7 +159,6 @@ struct BillingSubscriptionJson {
id: BillingSubscriptionId,
name: String,
status: StripeSubscriptionStatus,
trial_end_at: Option<String>,
cancel_at: Option<String>,
/// Whether this subscription can be canceled.
is_cancelable: bool,
@@ -192,21 +188,9 @@ async fn list_billing_subscriptions(
id: subscription.id,
name: match subscription.kind {
Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
None => "Zed LLM Usage".to_string(),
},
status: subscription.stripe_subscription_status,
trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
maybe!({
let end_at = subscription.stripe_current_period_end?;
let end_at = DateTime::from_timestamp(end_at, 0)?;
Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
})
} else {
None
},
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
cancel_at
.and_utc()
@@ -223,7 +207,6 @@ async fn list_billing_subscriptions(
#[serde(rename_all = "snake_case")]
enum ProductCode {
ZedPro,
ZedProTrial,
}
#[derive(Debug, Deserialize)]
@@ -303,36 +286,24 @@ async fn create_billing_subscription(
customer.id
};
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
let checkout_session_url = match body.product {
Some(ProductCode::ZedPro) => {
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
stripe_billing
.checkout_with_price(
app.config.zed_pro_price_id()?,
customer_id,
&user.github_login,
&success_url,
)
.await?
}
Some(ProductCode::ZedProTrial) => {
stripe_billing
.checkout_with_price(
app.config.zed_pro_trial_price_id()?,
customer_id,
&user.github_login,
&success_url,
)
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.await?
}
None => {
let default_model =
llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?
@@ -351,8 +322,6 @@ enum ManageSubscriptionIntent {
///
/// This will open the Stripe billing portal without putting the user in a specific flow.
ManageSubscription,
/// The user intends to upgrade to Zed Pro.
UpgradeToPro,
/// The user intends to cancel their subscription.
Cancel,
/// The user intends to stop the cancellation of their subscription.
@@ -404,10 +373,11 @@ async fn manage_billing_subscription(
.get_billing_subscription_by_id(body.subscription_id)
.await?
.ok_or_else(|| anyhow!("subscription not found"))?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
if body.intent == ManageSubscriptionIntent::StopCancellation {
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
let updated_stripe_subscription = Subscription::update(
&stripe_client,
&subscription_id,
@@ -440,47 +410,6 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id = app.config.zed_pro_price_id()?;
let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?;
let zed_free_price_id = app.config.zed_free_price_id()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
let subscription_item_to_update = stripe_subscription
.items
.data
.iter()
.find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id {
Some(item.id.clone())
} else {
None
}
})
.ok_or_else(|| anyhow!("No subscription item to update"))?;
Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
subscription_update_confirm: Some(
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
subscription: subscription.stripe_subscription_id,
items: vec![
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
id: subscription_item_to_update.to_string(),
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
},
],
discounts: None,
},
),
..Default::default()
})
}
ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
@@ -767,25 +696,22 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let subscription_kind = maybe!({
let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?;
let zed_free_price_id = app.config.zed_free_price_id().ok()?;
let subscription_kind =
if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
let has_zed_pro_price = subscription.items.data.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id.as_str() == zed_pro_price_id)
});
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_pro_price_id {
if has_zed_pro_price {
Some(SubscriptionKind::ZedPro)
} else if price.id == zed_pro_trial_price_id {
Some(SubscriptionKind::ZedProTrial)
} else if price.id == zed_free_price_id {
Some(SubscriptionKind::ZedFree)
} else {
None
}
})
});
} else {
None
};
let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer)
@@ -948,93 +874,6 @@ async fn get_monthly_spend(
}))
}
#[derive(Debug, Deserialize)]
struct GetCurrentUsageParams {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct UsageCounts {
pub used: i32,
pub limit: Option<i32>,
pub remaining: Option<i32>,
}
#[derive(Debug, Serialize)]
struct GetCurrentUsageResponse {
pub model_requests: UsageCounts,
pub edit_predictions: UsageCounts,
}
async fn get_current_usage(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetCurrentUsageParams>,
) -> Result<Json<GetCurrentUsageResponse>> {
let user = app
.db
.get_user_by_github_user_id(params.github_user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some(llm_db) = app.llm_db.clone() else {
return Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"LLM database not available".into(),
));
};
let empty_usage = GetCurrentUsageResponse {
model_requests: UsageCounts {
used: 0,
limit: Some(0),
remaining: Some(0),
},
edit_predictions: UsageCounts {
used: 0,
limit: Some(0),
remaining: Some(0),
},
};
let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
return Ok(Json(empty_usage));
};
let subscription_period = maybe!({
let period_start_at = subscription.current_period_start_at()?;
let period_end_at = subscription.current_period_end_at()?;
Some((period_start_at, period_end_at))
});
let Some((period_start_at, period_end_at)) = subscription_period else {
return Ok(Json(empty_usage));
};
let usage = llm_db
.get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
.await?;
let Some(usage) = usage else {
return Ok(Json(empty_usage));
};
let model_requests_limit = Some(500);
let edit_prediction_limit = Some(2000);
Ok(Json(GetCurrentUsageResponse {
model_requests: UsageCounts {
used: usage.model_requests,
limit: model_requests_limit,
remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
},
edit_predictions: UsageCounts {
used: usage.edit_predictions,
limit: edit_prediction_limit,
remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
},
}))
}
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
fn from(value: SubscriptionStatus) -> Self {
match value {

View File

@@ -62,14 +62,11 @@ impl Database {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
kind: params.kind.clone(),
stripe_subscription_id: params.stripe_subscription_id.clone(),
stripe_subscription_status: params.stripe_subscription_status.clone(),
stripe_cancel_at: params.stripe_cancel_at.clone(),
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
stripe_current_period_start: params.stripe_current_period_start.clone(),
stripe_current_period_end: params.stripe_current_period_end.clone(),
created_at: ActiveValue::not_set(),
..Default::default()
})
.exec(&*tx)
.await?;
@@ -108,28 +105,6 @@ impl Database {
.await
}
pub async fn get_active_billing_subscription(
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
.filter(
Condition::all()
.add(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.add(billing_subscription::Column::Kind.is_not_null()),
)
.one(&*tx)
.await?)
})
.await
}
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
@@ -167,7 +142,6 @@ impl Database {
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.is_null())
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;

View File

@@ -19,18 +19,6 @@ pub struct Model {
pub created_at: DateTime,
}
impl Model {
pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
let period_start = self.stripe_current_period_start?;
chrono::DateTime::from_timestamp(period_start, 0)
}
pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
let period_end = self.stripe_current_period_end?;
chrono::DateTime::from_timestamp(period_end, 0)
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
@@ -55,10 +43,6 @@ impl ActiveModelBehavior for ActiveModel {}
pub enum SubscriptionKind {
#[sea_orm(string_value = "zed_pro")]
ZedPro,
#[sea_orm(string_value = "zed_pro_trial")]
ZedProTrial,
#[sea_orm(string_value = "zed_free")]
ZedFree,
}
/// The status of a Stripe subscription.

View File

@@ -183,8 +183,6 @@ pub struct Config {
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_zed_pro_price_id: Option<String>,
pub stripe_zed_pro_trial_price_id: Option<String>,
pub stripe_zed_free_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -203,29 +201,6 @@ impl Config {
}
}
pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
}
pub fn zed_pro_trial_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id(
"Zed Pro Trial",
self.stripe_zed_pro_trial_price_id.as_deref(),
)
}
pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
}
fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
use std::str::FromStr as _;
let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
Ok(stripe::PriceId::from_str(price_id)?)
}
#[cfg(test)]
pub fn test() -> Self {
Self {
@@ -264,8 +239,6 @@ impl Config {
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,
@@ -351,9 +324,12 @@ impl AppState {
llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_billing: stripe_client
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_billing: stripe_client.clone().map(|stripe_client| {
Arc::new(StripeBilling::new(
stripe_client,
config.stripe_zed_pro_price_id.clone(),
))
}),
stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,

View File

@@ -2,5 +2,4 @@ use super::*;
pub mod billing_events;
pub mod providers;
pub mod subscription_usages;
pub mod usages;

View File

@@ -1,22 +0,0 @@
use crate::db::UserId;
use super::*;
impl LlmDatabase {
pub async fn get_subscription_usage_for_period(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(&*tx)
.await?)
})
.await
}
}

View File

@@ -2,6 +2,5 @@ pub mod billing_event;
pub mod model;
pub mod monthly_usage;
pub mod provider;
pub mod subscription_usage;
pub mod usage;
pub mod usage_measure;

View File

@@ -1,20 +0,0 @@
use crate::db::UserId;
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub user_id: UserId,
pub period_start_at: PrimitiveDateTime,
pub period_end_at: PrimitiveDateTime,
pub model_requests: i32,
pub edit_predictions: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,5 +1,5 @@
use crate::Cents;
use crate::db::{billing_subscription, user};
use crate::db::user;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::{Config, db::billing_preference};
use anyhow::{Result, anyhow};
@@ -8,7 +8,6 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use util::maybe;
use uuid::Uuid;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
@@ -30,8 +29,6 @@ pub struct LlmTokenClaims {
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
pub plan: rpc::proto::Plan,
#[serde(default)]
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
@@ -42,9 +39,8 @@ impl LlmTokenClaims {
is_staff: bool,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
has_legacy_llm_subscription: bool,
has_llm_subscription: bool,
plan: rpc::proto::Plan,
subscription: Option<billing_subscription::Model>,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
@@ -73,7 +69,7 @@ impl LlmTokenClaims {
has_predict_edits_feature_flag: feature_flags
.iter()
.any(|flag| flag == "predict-edits"),
has_llm_subscription: has_legacy_llm_subscription,
has_llm_subscription,
max_monthly_spend_in_cents: billing_preferences
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
preferences.max_monthly_llm_usage_spending_in_cents as u32
@@ -82,13 +78,6 @@ impl LlmTokenClaims {
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
plan,
subscription_period: maybe!({
let subscription = subscription?;
let period_start_at = subscription.current_period_start_at()?;
let period_end_at = subscription.current_period_end_at()?;
Some((period_start_at.naive_utc(), period_end_at.naive_utc()))
}),
};
Ok(jsonwebtoken::encode(

View File

@@ -4135,8 +4135,7 @@ async fn get_llm_api_token(
Err(anyhow!("terms of service not accepted"))?
}
let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
let has_llm_subscription = session.has_llm_subscription(&db).await?;
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(
@@ -4144,9 +4143,8 @@ async fn get_llm_api_token(
session.is_staff(),
billing_preferences,
&flags,
has_legacy_llm_subscription,
has_llm_subscription,
session.current_plan(&db).await?,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,
)?;

View File

@@ -1,16 +1,16 @@
use std::sync::Arc;
use crate::{Cents, Result, llm};
use anyhow::Context as _;
use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::PriceId;
use tokio::sync::RwLock;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
zed_pro_price_id: Option<String>,
}
#[derive(Default)]
@@ -32,10 +32,11 @@ struct StripeBillingPrice {
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
Self {
client,
state: RwLock::default(),
zed_pro_price_id,
}
}
@@ -384,19 +385,23 @@ impl StripeBilling {
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_price(
pub async fn checkout_with_zed_pro(
&self,
price_id: PriceId,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self
.zed_pro_price_id
.as_ref()
.ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(price_id.to_string()),
price: Some(zed_pro_price_id.clone()),
quantity: Some(1),
..Default::default()
}]);

View File

@@ -558,8 +558,6 @@ impl TestServer {
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View File

@@ -12,9 +12,8 @@ use dap::{
};
use futures::{SinkExt as _, channel::mpsc};
use gpui::{
Action, App, AsyncWindowContext, Context, DismissEvent, Entity, EntityId, EventEmitter,
FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task, WeakEntity,
actions, anchored, deferred,
Action, App, AsyncWindowContext, Context, Entity, EntityId, EventEmitter, FocusHandle,
Focusable, Subscription, Task, WeakEntity, actions,
};
use project::{
@@ -65,7 +64,6 @@ pub struct DebugPanel {
project: WeakEntity<Project>,
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
_subscriptions: Vec<Subscription>,
}
@@ -128,7 +126,6 @@ impl DebugPanel {
focus_handle: cx.focus_handle(),
project: project.downgrade(),
workspace: workspace.weak_handle(),
context_menu: None,
};
debug_panel
@@ -441,13 +438,7 @@ impl DebugPanel {
else {
return;
};
session.update(cx, |this, cx| {
if let Some(running) = this.mode().as_running() {
running.update(cx, |this, cx| {
this.serialize_layout(window, cx);
});
}
});
let session_id = session.update(cx, |this, cx| this.session_id(cx));
let should_prompt = self
.project
@@ -576,57 +567,6 @@ impl DebugPanel {
)
}
fn deploy_context_menu(
&mut self,
position: Point<Pixels>,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(running_state) = self
.active_session
.as_ref()
.and_then(|session| session.read(cx).mode().as_running().cloned())
{
let pane_items_status = running_state.read(cx).pane_items_status(cx);
let this = cx.weak_entity();
let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| {
for (item_kind, is_visible) in pane_items_status.into_iter() {
menu = menu.toggleable_entry(item_kind, is_visible, IconPosition::End, None, {
let this = this.clone();
move |window, cx| {
this.update(cx, |this, cx| {
if let Some(running_state) =
this.active_session.as_ref().and_then(|session| {
session.read(cx).mode().as_running().cloned()
})
{
running_state.update(cx, |state, cx| {
if is_visible {
state.remove_pane_item(item_kind, window, cx);
} else {
state.add_pane_item(item_kind, position, window, cx);
}
})
}
})
.ok();
}
});
}
menu
});
window.focus(&context_menu.focus_handle(cx));
let subscription = cx.subscribe(&context_menu, |this, _, _: &DismissEvent, cx| {
this.context_menu.take();
cx.notify();
});
self.context_menu = Some((context_menu, position, subscription));
}
}
fn top_controls_strip(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
let active_session = self.active_session.clone();
@@ -951,49 +891,11 @@ impl Render for DebugPanel {
let has_sessions = self.sessions.len() > 0;
debug_assert_eq!(has_sessions, self.active_session.is_some());
if self
.active_session
.as_ref()
.and_then(|session| session.read(cx).mode().as_running().cloned())
.map(|state| state.read(cx).has_open_context_menu(cx))
.unwrap_or(false)
{
self.context_menu.take();
}
v_flex()
.size_full()
.key_context("DebugPanel")
.child(h_flex().children(self.top_controls_strip(window, cx)))
.track_focus(&self.focus_handle(cx))
.when(self.active_session.is_some(), |this| {
this.on_mouse_down(
MouseButton::Right,
cx.listener(|this, event: &MouseDownEvent, window, cx| {
if this
.active_session
.as_ref()
.and_then(|session| {
session.read(cx).mode().as_running().map(|state| {
state.read(cx).has_pane_at_position(event.position)
})
})
.unwrap_or(false)
{
this.deploy_context_menu(event.position, window, cx);
}
}),
)
.children(self.context_menu.as_ref().map(|(menu, position, _)| {
deferred(
anchored()
.position(*position)
.anchor(gpui::Corner::TopLeft)
.child(menu.clone()),
)
.with_priority(1)
}))
})
.map(|this| {
if has_sessions {
this.children(self.active_session.clone())

View File

@@ -1,5 +1,4 @@
use collections::HashMap;
use dap::Capabilities;
use db::kvp::KEY_VALUE_STORE;
use gpui::{Axis, Context, Entity, EntityId, Focusable, Subscription, WeakEntity, Window};
use project::Project;
@@ -10,43 +9,19 @@ use workspace::{Member, Pane, PaneAxis, Workspace};
use crate::session::running::{
self, RunningState, SubView, breakpoint_list::BreakpointList, console::Console,
loaded_source_list::LoadedSourceList, module_list::ModuleList,
stack_frame_list::StackFrameList, variable_list::VariableList,
module_list::ModuleList, stack_frame_list::StackFrameList, variable_list::VariableList,
};
#[derive(Clone, Hash, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) enum DebuggerPaneItem {
Console,
Variables,
BreakpointList,
Frames,
Modules,
LoadedSources,
}
impl DebuggerPaneItem {
pub(crate) fn all() -> &'static [DebuggerPaneItem] {
static VARIANTS: &[DebuggerPaneItem] = &[
DebuggerPaneItem::Console,
DebuggerPaneItem::Variables,
DebuggerPaneItem::BreakpointList,
DebuggerPaneItem::Frames,
DebuggerPaneItem::Modules,
DebuggerPaneItem::LoadedSources,
];
VARIANTS
}
pub(crate) fn is_supported(&self, capabilities: &Capabilities) -> bool {
match self {
DebuggerPaneItem::Modules => capabilities.supports_modules_request.unwrap_or_default(),
DebuggerPaneItem::LoadedSources => capabilities
.supports_loaded_sources_request
.unwrap_or_default(),
_ => true,
}
}
pub(crate) fn to_shared_string(self) -> SharedString {
match self {
DebuggerPaneItem::Console => SharedString::new_static("Console"),
@@ -54,17 +29,10 @@ impl DebuggerPaneItem {
DebuggerPaneItem::BreakpointList => SharedString::new_static("Breakpoints"),
DebuggerPaneItem::Frames => SharedString::new_static("Frames"),
DebuggerPaneItem::Modules => SharedString::new_static("Modules"),
DebuggerPaneItem::LoadedSources => SharedString::new_static("Sources"),
}
}
}
impl From<DebuggerPaneItem> for SharedString {
fn from(item: DebuggerPaneItem) -> Self {
item.to_shared_string()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SerializedAxis(pub Axis);
@@ -168,7 +136,6 @@ pub(crate) fn deserialize_pane_layout(
module_list: &Entity<ModuleList>,
console: &Entity<Console>,
breakpoint_list: &Entity<BreakpointList>,
loaded_sources: &Entity<LoadedSourceList>,
subscriptions: &mut HashMap<EntityId, Subscription>,
window: &mut Window,
cx: &mut Context<RunningState>,
@@ -190,7 +157,6 @@ pub(crate) fn deserialize_pane_layout(
module_list,
console,
breakpoint_list,
loaded_sources,
subscriptions,
window,
cx,
@@ -225,7 +191,7 @@ pub(crate) fn deserialize_pane_layout(
.iter()
.map(|child| match child {
DebuggerPaneItem::Frames => Box::new(SubView::new(
stack_frame_list.focus_handle(cx),
pane.focus_handle(cx),
stack_frame_list.clone().into(),
DebuggerPaneItem::Frames,
None,
@@ -246,19 +212,13 @@ pub(crate) fn deserialize_pane_layout(
cx,
)),
DebuggerPaneItem::Modules => Box::new(SubView::new(
module_list.focus_handle(cx),
pane.focus_handle(cx),
module_list.clone().into(),
DebuggerPaneItem::Modules,
None,
cx,
)),
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
loaded_sources.focus_handle(cx),
loaded_sources.clone().into(),
DebuggerPaneItem::LoadedSources,
None,
cx,
)),
DebuggerPaneItem::Console => Box::new(SubView::new(
pane.focus_handle(cx),
console.clone().into(),

View File

@@ -11,12 +11,12 @@ use crate::persistence::{self, DebuggerPaneItem, SerializedPaneLayout};
use super::DebugPanelItemEvent;
use breakpoint_list::BreakpointList;
use collections::{HashMap, IndexMap};
use collections::HashMap;
use console::Console;
use dap::{Capabilities, Thread, client::SessionId, debugger_settings::DebuggerSettings};
use gpui::{
Action as _, AnyView, AppContext, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
NoAction, Pixels, Point, Subscription, Task, WeakEntity,
NoAction, Subscription, Task, WeakEntity,
};
use loaded_source_list::LoadedSourceList;
use module_list::ModuleList;
@@ -49,10 +49,8 @@ pub struct RunningState {
variable_list: Entity<variable_list::VariableList>,
_subscriptions: Vec<Subscription>,
stack_frame_list: Entity<stack_frame_list::StackFrameList>,
loaded_sources_list: Entity<LoadedSourceList>,
module_list: Entity<module_list::ModuleList>,
_module_list: Entity<module_list::ModuleList>,
_console: Entity<Console>,
breakpoint_list: Entity<BreakpointList>,
panes: PaneGroup,
pane_close_subscriptions: HashMap<EntityId, Subscription>,
_schedule_serialize: Option<Task<()>>,
@@ -385,6 +383,7 @@ impl RunningState {
let module_list = cx.new(|cx| ModuleList::new(session.clone(), workspace.clone(), cx));
#[expect(unused)]
let loaded_source_list = cx.new(|cx| LoadedSourceList::new(session.clone(), cx));
let console = cx.new(|cx| {
@@ -397,7 +396,7 @@ impl RunningState {
)
});
let breakpoint_list = BreakpointList::new(session.clone(), workspace.clone(), &project, cx);
let breakpoints = BreakpointList::new(session.clone(), workspace.clone(), &project, cx);
let _subscriptions = vec![
cx.observe(&module_list, |_, _, cx| cx.notify()),
@@ -422,9 +421,6 @@ impl RunningState {
}
cx.notify()
}),
cx.on_focus_out(&focus_handle, window, |this, _, window, cx| {
this.serialize_layout(window, cx);
}),
];
let mut pane_close_subscriptions = HashMap::default();
@@ -437,8 +433,7 @@ impl RunningState {
&variable_list,
&module_list,
&console,
&breakpoint_list,
&loaded_source_list,
&breakpoints,
&mut pane_close_subscriptions,
window,
cx,
@@ -454,7 +449,7 @@ impl RunningState {
&variable_list,
&module_list,
&console,
&breakpoint_list,
breakpoints,
&mut pane_close_subscriptions,
window,
cx,
@@ -474,140 +469,14 @@ impl RunningState {
stack_frame_list,
session_id,
panes,
module_list,
_module_list: module_list,
_console: console,
breakpoint_list,
loaded_sources_list: loaded_source_list,
pane_close_subscriptions,
_schedule_serialize: None,
}
}
pub(crate) fn remove_pane_item(
&mut self,
item_kind: DebuggerPaneItem,
window: &mut Window,
cx: &mut Context<Self>,
) {
debug_assert!(
item_kind.is_supported(self.session.read(cx).capabilities()),
"We should only allow removing supported item kinds"
);
if let Some((pane, item_id)) = self.panes.panes().iter().find_map(|pane| {
Some(pane).zip(
pane.read(cx)
.items()
.find(|item| {
item.act_as::<SubView>(cx)
.is_some_and(|view| view.read(cx).kind == item_kind)
})
.map(|item| item.item_id()),
)
}) {
pane.update(cx, |pane, cx| {
pane.remove_item(item_id, false, true, window, cx)
})
}
}
pub(crate) fn has_pane_at_position(&self, position: Point<Pixels>) -> bool {
self.panes.pane_at_pixel_position(position).is_some()
}
pub(crate) fn add_pane_item(
&mut self,
item_kind: DebuggerPaneItem,
position: Point<Pixels>,
window: &mut Window,
cx: &mut Context<Self>,
) {
debug_assert!(
item_kind.is_supported(self.session.read(cx).capabilities()),
"We should only allow adding supported item kinds"
);
if let Some(pane) = self.panes.pane_at_pixel_position(position) {
let sub_view = match item_kind {
DebuggerPaneItem::Console => {
let weak_console = self._console.clone().downgrade();
Box::new(SubView::new(
pane.focus_handle(cx),
self._console.clone().into(),
item_kind,
Some(Box::new(move |cx| {
weak_console
.read_with(cx, |console, cx| console.show_indicator(cx))
.unwrap_or_default()
})),
cx,
))
}
DebuggerPaneItem::Variables => Box::new(SubView::new(
self.variable_list.focus_handle(cx),
self.variable_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::BreakpointList => Box::new(SubView::new(
self.breakpoint_list.focus_handle(cx),
self.breakpoint_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Frames => Box::new(SubView::new(
self.stack_frame_list.focus_handle(cx),
self.stack_frame_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::Modules => Box::new(SubView::new(
self.module_list.focus_handle(cx),
self.module_list.clone().into(),
item_kind,
None,
cx,
)),
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
self.loaded_sources_list.focus_handle(cx),
self.loaded_sources_list.clone().into(),
item_kind,
None,
cx,
)),
};
pane.update(cx, |pane, cx| {
pane.add_item(sub_view, false, false, None, window, cx);
})
}
}
pub(crate) fn pane_items_status(&self, cx: &App) -> IndexMap<DebuggerPaneItem, bool> {
let caps = self.session.read(cx).capabilities();
let mut pane_item_status = IndexMap::from_iter(
DebuggerPaneItem::all()
.iter()
.filter(|kind| kind.is_supported(&caps))
.map(|kind| (*kind, false)),
);
self.panes.panes().iter().for_each(|pane| {
pane.read(cx)
.items()
.filter_map(|item| item.act_as::<SubView>(cx))
.for_each(|view| {
pane_item_status.insert(view.read(cx).kind, true);
});
});
pane_item_status
}
pub(crate) fn serialize_layout(&mut self, window: &mut Window, cx: &mut Context<Self>) {
fn serialize_layout(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self._schedule_serialize.is_none() {
self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| {
cx.background_executor()
@@ -661,10 +530,6 @@ impl RunningState {
}
}
pub(crate) fn has_open_context_menu(&self, cx: &App) -> bool {
self.variable_list.read(cx).has_open_context_menu()
}
pub fn session(&self) -> &Entity<Session> {
&self.session
}
@@ -689,7 +554,7 @@ impl RunningState {
#[cfg(test)]
pub(crate) fn module_list(&self) -> &Entity<ModuleList> {
&self.module_list
&self._module_list
}
#[cfg(test)]
@@ -925,7 +790,7 @@ impl RunningState {
variable_list: &Entity<VariableList>,
module_list: &Entity<ModuleList>,
console: &Entity<Console>,
breakpoints: &Entity<BreakpointList>,
breakpoints: Entity<BreakpointList>,
subscriptions: &mut HashMap<EntityId, Subscription>,
window: &mut Window,
cx: &mut Context<'_, RunningState>,
@@ -949,7 +814,7 @@ impl RunningState {
this.add_item(
Box::new(SubView::new(
breakpoints.focus_handle(cx),
breakpoints.clone().into(),
breakpoints.into(),
DebuggerPaneItem::BreakpointList,
None,
cx,

View File

@@ -3,7 +3,7 @@ use project::debugger::session::{Session, SessionEvent};
use ui::prelude::*;
use util::maybe;
pub(crate) struct LoadedSourceList {
pub struct LoadedSourceList {
list: ListState,
invalidate: bool,
focus_handle: FocusHandle,

View File

@@ -194,10 +194,6 @@ impl VariableList {
}
}
pub(super) fn has_open_context_menu(&self) -> bool {
self.open_context_menu.is_some()
}
fn build_entries(&mut self, cx: &mut Context<Self>) {
let Some(stack_frame_id) = self.selected_stack_frame_id else {
return;

View File

@@ -46,8 +46,7 @@ use workspace::{
actions!(diagnostics, [Deploy, ToggleWarnings]);
#[derive(Default)]
pub(crate) struct IncludeWarnings(bool);
struct IncludeWarnings(bool);
impl Global for IncludeWarnings {}
pub fn init(cx: &mut App) {
@@ -210,7 +209,6 @@ impl ProjectDiagnosticsEditor {
.detach();
cx.observe_global_in::<IncludeWarnings>(window, |this, window, cx| {
this.include_warnings = cx.global::<IncludeWarnings>().0;
this.diagnostics.clear();
this.update_all_excerpts(window, cx);
})
.detach();
@@ -302,8 +300,11 @@ impl ProjectDiagnosticsEditor {
}
}
fn toggle_warnings(&mut self, _: &ToggleWarnings, _: &mut Window, cx: &mut Context<Self>) {
cx.set_global(IncludeWarnings(!self.include_warnings));
fn toggle_warnings(&mut self, _: &ToggleWarnings, window: &mut Window, cx: &mut Context<Self>) {
self.include_warnings = !self.include_warnings;
cx.set_global(IncludeWarnings(self.include_warnings));
self.update_all_excerpts(window, cx);
cx.notify();
}
fn focus_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -380,6 +381,7 @@ impl ProjectDiagnosticsEditor {
Point::zero()..buffer_snapshot.max_point(),
false,
)
.filter(|d| !(d.diagnostic.is_primary && d.diagnostic.is_unnecessary))
.collect::<Vec<_>>();
let unchanged = this.update(cx, |this, _| {
if this.diagnostics.get(&buffer_id).is_some_and(|existing| {
@@ -480,10 +482,7 @@ impl ProjectDiagnosticsEditor {
editor.change_selections(Some(Autoscroll::fit()), window, cx, |s| {
s.select_anchor_ranges([range_to_select]);
})
});
if this.focus_handle.is_focused(window) {
this.editor.read(cx).focus_handle(cx).focus(window);
}
})
}
}

View File

@@ -1,9 +1,9 @@
use super::*;
use collections::{HashMap, HashSet};
use editor::{
DisplayPoint, InlayId,
DisplayPoint,
actions::{GoToDiagnostic, GoToPreviousDiagnostic, MoveToBeginning},
display_map::{DisplayRow, Inlay},
display_map::DisplayRow,
test::{editor_content_with_blocks, editor_test_context::EditorTestContext},
};
use gpui::{TestAppContext, VisualTestContext};
@@ -620,7 +620,7 @@ async fn test_diagnostics_multiple_servers(cx: &mut TestAppContext) {
}
#[gpui::test(iterations = 20)]
async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng) {
async fn test_random_diagnostics(cx: &mut TestAppContext, mut rng: StdRng) {
init_test(cx);
let operations = env::var("OPERATIONS")
@@ -779,162 +779,6 @@ async fn test_random_diagnostics_blocks(cx: &mut TestAppContext, mut rng: StdRng
}
}
// similar to above, but with inlays. Used to find panics when mixing diagnostics and inlays.
#[gpui::test]
async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: StdRng) {
init_test(cx);
let operations = env::var("OPERATIONS")
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
.unwrap_or(10);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*window, cx);
let workspace = window.root(cx).unwrap();
let mutated_diagnostics = window.build_entity(cx, |window, cx| {
ProjectDiagnosticsEditor::new(true, project.clone(), workspace.downgrade(), window, cx)
});
workspace.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_center(Box::new(mutated_diagnostics.clone()), window, cx);
});
mutated_diagnostics.update_in(cx, |diagnostics, window, _cx| {
assert!(diagnostics.focus_handle.is_focused(window));
});
let mut next_id = 0;
let mut next_filename = 0;
let mut language_server_ids = vec![LanguageServerId(0)];
let mut updated_language_servers = HashSet::default();
let mut current_diagnostics: HashMap<(PathBuf, LanguageServerId), Vec<lsp::Diagnostic>> =
Default::default();
let mut next_inlay_id = 0;
for _ in 0..operations {
match rng.gen_range(0..100) {
// language server completes its diagnostic check
0..=20 if !updated_language_servers.is_empty() => {
let server_id = *updated_language_servers.iter().choose(&mut rng).unwrap();
log::info!("finishing diagnostic check for language server {server_id}");
lsp_store.update(cx, |lsp_store, cx| {
lsp_store.disk_based_diagnostics_finished(server_id, cx)
});
if rng.gen_bool(0.5) {
cx.run_until_parked();
}
}
21..=50 => mutated_diagnostics.update_in(cx, |diagnostics, window, cx| {
diagnostics.editor.update(cx, |editor, cx| {
let snapshot = editor.snapshot(window, cx);
if snapshot.buffer_snapshot.len() > 0 {
let position = rng.gen_range(0..snapshot.buffer_snapshot.len());
let position = snapshot.buffer_snapshot.clip_offset(position, Bias::Left);
log::info!(
"adding inlay at {position}/{}: {:?}",
snapshot.buffer_snapshot.len(),
snapshot.buffer_snapshot.text(),
);
editor.splice_inlays(
&[],
vec![Inlay {
id: InlayId::InlineCompletion(post_inc(&mut next_inlay_id)),
position: snapshot.buffer_snapshot.anchor_before(position),
text: Rope::from(format!("Test inlay {next_inlay_id}")),
}],
cx,
);
}
});
}),
// language server updates diagnostics
_ => {
let (path, server_id, diagnostics) =
match current_diagnostics.iter_mut().choose(&mut rng) {
// update existing set of diagnostics
Some(((path, server_id), diagnostics)) if rng.gen_bool(0.5) => {
(path.clone(), *server_id, diagnostics)
}
// insert a set of diagnostics for a new path
_ => {
let path: PathBuf =
format!(path!("/test/{}.rs"), post_inc(&mut next_filename)).into();
let len = rng.gen_range(128..256);
let content =
RandomCharIter::new(&mut rng).take(len).collect::<String>();
fs.insert_file(&path, content.into_bytes()).await;
let server_id = match language_server_ids.iter().choose(&mut rng) {
Some(server_id) if rng.gen_bool(0.5) => *server_id,
_ => {
let id = LanguageServerId(language_server_ids.len());
language_server_ids.push(id);
id
}
};
(
path.clone(),
server_id,
current_diagnostics.entry((path, server_id)).or_default(),
)
}
};
updated_language_servers.insert(server_id);
lsp_store.update(cx, |lsp_store, cx| {
log::info!("updating diagnostics. language server {server_id} path {path:?}");
randomly_update_diagnostics_for_path(
&fs,
&path,
diagnostics,
&mut next_id,
&mut rng,
);
lsp_store
.update_diagnostics(
server_id,
lsp::PublishDiagnosticsParams {
uri: lsp::Url::from_file_path(&path).unwrap_or_else(|_| {
lsp::Url::parse("file:///test/fallback.rs").unwrap()
}),
diagnostics: diagnostics.clone(),
version: None,
},
&[],
cx,
)
.unwrap()
});
cx.executor()
.advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10));
cx.run_until_parked();
}
}
}
log::info!("updating mutated diagnostics view");
mutated_diagnostics.update_in(cx, |diagnostics, window, cx| {
diagnostics.update_stale_excerpts(window, cx)
});
cx.executor()
.advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10));
cx.run_until_parked();
}
#[gpui::test]
async fn active_diagnostics_dismiss_after_invalidation(cx: &mut TestAppContext) {
init_test(cx);

View File

@@ -9,7 +9,7 @@ use language::Diagnostic;
use ui::{Button, ButtonLike, Color, Icon, IconName, Label, Tooltip, h_flex, prelude::*};
use workspace::{StatusItemView, ToolbarItemEvent, Workspace, item::ItemHandle};
use crate::{Deploy, IncludeWarnings, ProjectDiagnosticsEditor};
use crate::{Deploy, ProjectDiagnosticsEditor};
pub struct DiagnosticIndicator {
summary: project::DiagnosticSummary,
@@ -94,11 +94,6 @@ impl Render for DiagnosticIndicator {
})
.on_click(cx.listener(|this, _, window, cx| {
if let Some(workspace) = this.workspace.upgrade() {
if this.summary.error_count == 0 && this.summary.warning_count > 0 {
cx.update_default_global(
|show_warnings: &mut IncludeWarnings, _| show_warnings.0 = true,
);
}
workspace.update(cx, |workspace, cx| {
ProjectDiagnosticsEditor::deploy(
workspace,

View File

@@ -49,8 +49,8 @@ use language::{
};
use lsp::DiagnosticSeverity;
use multi_buffer::{
Anchor, AnchorRangeExt, ExcerptId, MultiBuffer, MultiBufferPoint, MultiBufferRow,
MultiBufferSnapshot, RowInfo, ToOffset, ToPoint,
Anchor, AnchorRangeExt, MultiBuffer, MultiBufferPoint, MultiBufferRow, MultiBufferSnapshot,
RowInfo, ToOffset, ToPoint,
};
use serde::Deserialize;
use std::{
@@ -574,21 +574,6 @@ impl DisplayMap {
self.block_map.read(snapshot, edits);
}
pub fn remove_inlays_for_excerpts(&mut self, excerpts_removed: &[ExcerptId]) {
let to_remove = self
.inlay_map
.current_inlays()
.filter_map(|inlay| {
if excerpts_removed.contains(&inlay.position.excerpt_id) {
Some(inlay.id)
} else {
None
}
})
.collect::<Vec<_>>();
self.inlay_map.splice(&to_remove, Vec::new());
}
fn tab_size(buffer: &Entity<MultiBuffer>, cx: &App) -> NonZeroU32 {
let buffer = buffer.read(cx).as_singleton().map(|buffer| buffer.read(cx));
let language = buffer

View File

@@ -36,7 +36,7 @@ enum Transform {
#[derive(Debug, Clone)]
pub struct Inlay {
pub id: InlayId,
pub(crate) id: InlayId,
pub position: Anchor,
pub text: text::Rope,
}
@@ -482,9 +482,6 @@ impl InlayMap {
};
for inlay in &self.inlays[start_ix..] {
if !inlay.position.is_valid(&buffer_snapshot) {
continue;
}
let buffer_offset = inlay.position.to_offset(&buffer_snapshot);
if buffer_offset > buffer_edit.new.end {
break;
@@ -497,7 +494,9 @@ impl InlayMap {
buffer_snapshot.text_summary_for_range(prefix_start..prefix_end),
);
new_transforms.push(Transform::Inlay(inlay.clone()), &());
if inlay.position.is_valid(&buffer_snapshot) {
new_transforms.push(Transform::Inlay(inlay.clone()), &());
}
}
// Apply the rest of the edit.

View File

@@ -4170,13 +4170,10 @@ impl Editor {
if let Some(InlaySplice {
to_remove,
to_insert,
}) = self.inlay_hint_cache.remove_excerpts(&excerpts_removed)
}) = self.inlay_hint_cache.remove_excerpts(excerpts_removed)
{
self.splice_inlays(&to_remove, to_insert, cx);
}
self.display_map.update(cx, |display_map, _| {
display_map.remove_inlays_for_excerpts(&excerpts_removed)
});
return;
}
InlayHintRefreshReason::NewLinesShown => (InvalidationStrategy::None, None),
@@ -4740,8 +4737,8 @@ impl Editor {
let lookahead = replace_range
.end
.saturating_sub(newest_anchor.end.text_anchor.to_offset(buffer));
let prefix = &old_text[..old_text.len().saturating_sub(lookahead)];
let suffix = &old_text[lookbehind.min(old_text.len())..];
let prefix = &old_text[..old_text.len() - lookahead];
let suffix = &old_text[lookbehind..];
let selections = self.selections.all::<usize>(cx);
let mut edits = Vec::new();
@@ -4756,7 +4753,7 @@ impl Editor {
// if prefix is present, don't duplicate it
if snapshot.contains_str_at(range.start.saturating_sub(lookbehind), prefix) {
text = &new_text[lookbehind.min(new_text.len())..];
text = &new_text[lookbehind..];
// if suffix is also present, mimic the newest cursor and replace it
if selection.id != newest_anchor.id
@@ -13726,6 +13723,8 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<Task<Result<Navigated>>> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::TypingAction);
let selection = self.selections.newest::<usize>(cx);
let multi_buffer = self.buffer.read(cx);
let head = selection.head();

View File

@@ -555,12 +555,12 @@ impl InlayHintCache {
/// Completely forget of certain excerpts that were removed from the multibuffer.
pub(super) fn remove_excerpts(
&mut self,
excerpts_removed: &[ExcerptId],
excerpts_removed: Vec<ExcerptId>,
) -> Option<InlaySplice> {
let mut to_remove = Vec::new();
for excerpt_to_remove in excerpts_removed {
self.update_tasks.remove(excerpt_to_remove);
if let Some(cached_hints) = self.hints.remove(excerpt_to_remove) {
self.update_tasks.remove(&excerpt_to_remove);
if let Some(cached_hints) = self.hints.remove(&excerpt_to_remove) {
let cached_hints = cached_hints.read();
to_remove.extend(cached_hints.ordered_hints.iter().copied());
}
@@ -989,16 +989,6 @@ fn fetch_and_update_hints(
}
let buffer = editor.buffer().read(cx).buffer(query.buffer_id)?;
if !editor.registered_buffers.contains_key(&query.buffer_id) {
if let Some(project) = editor.project.as_ref() {
project.update(cx, |project, cx| {
editor.registered_buffers.insert(
query.buffer_id,
project.register_buffer_with_language_servers(&buffer, cx),
);
})
}
}
editor
.semantics_provider
.as_ref()?

View File

@@ -16,7 +16,6 @@ client.workspace = true
collections.workspace = true
context_server.workspace = true
dap.workspace = true
dirs = "5.0"
env_logger.workspace = true
extension.workspace = true
fs.workspace = true
@@ -28,7 +27,7 @@ language.workspace = true
language_extension.workspace = true
language_model.workspace = true
language_models.workspace = true
languages.workspace = true
languages = { workspace = true, features = ["load-grammars"] }
node_runtime.workspace = true
paths.workspace = true
project.workspace = true
@@ -36,13 +35,12 @@ prompt_store.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
telemetry.workspace = true
toml.workspace = true
unindent.workspace = true
util.workspace = true
uuid = { version = "1.6", features = ["v4"] }
workspace-hack.workspace = true
[[bin]]

View File

@@ -1,3 +1,3 @@
Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation.
Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should implement the `Render` trait.
The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.

View File

@@ -1,16 +1,13 @@
mod example;
mod ids;
use client::{Client, ProxySettings, UserStore};
pub(crate) use example::*;
use telemetry;
use ::fs::RealFs;
use anyhow::{Result, anyhow};
use clap::Parser;
use extension::ExtensionHostProxy;
use futures::future;
use futures::stream::StreamExt;
use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
use gpui_tokio::Tokio;
@@ -42,18 +39,9 @@ struct Args {
/// Model to use (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model: String,
/// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
#[arg(long, value_delimiter = ',')]
languages: Option<Vec<String>>,
/// How many times to run each example. Note that this is currently not very efficient as N
/// worktrees will be created for the examples.
#[arg(long, default_value = "1")]
repetitions: u32,
/// How many times to run the judge on each example run.
#[arg(long, default_value = "3")]
judge_repetitions: u32,
/// Maximum number of examples to run concurrently.
#[arg(long, default_value = "10")]
concurrency: usize,
}
fn main() {
@@ -86,15 +74,6 @@ fn main() {
app.run(move |cx| {
let app_state = init(cx);
let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
let session_id = uuid::Uuid::new_v4().to_string();
app_state
.client
.telemetry()
.start(system_id, installation_id, session_id, cx);
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
@@ -150,20 +129,12 @@ fn main() {
continue;
}
// TODO: This creates a worktree per repetition. Ideally these examples should
// either be run sequentially on the same worktree, or reuse worktrees when there
// are more examples to run than the concurrency limit.
for repetition_number in 0..args.repetitions {
let mut example = example.clone();
example.set_repetition_number(repetition_number);
let name_len = example.name.len();
if name_len > max_name_width {
max_name_width = example.name.len();
}
examples.push(example);
let name_len = example.name.len();
if name_len > max_name_width {
max_name_width = example.name.len();
}
examples.push(example);
}
println!("Skipped examples: {}\n", skipped.join(", "));
@@ -183,7 +154,7 @@ fn main() {
println!(
"{}Logging to: {}",
example.log_prefix,
example.output_file_path.display()
example.run_directory_path.display()
);
let repo_url = example.base.url.clone();
@@ -232,26 +203,18 @@ fn main() {
example.setup().await?;
}
let judge_repetitions = args.judge_repetitions;
let concurrency = args.concurrency;
let tasks = examples
.into_iter()
.map(|example| {
let app_state = app_state.clone();
let model = model.clone();
cx.spawn(async move |cx| {
let result =
run_example(&example, model, app_state, judge_repetitions, cx).await;
(result, example)
(run_example(&example, model, app_state, cx).await, example)
})
})
.collect::<Vec<_>>();
let results = futures::stream::iter(tasks)
.buffer_unordered(concurrency)
.collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
.await;
let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
println!("\n\n");
println!("========================================");
@@ -260,37 +223,32 @@ fn main() {
println!("");
let mut judge_scores = Vec::new();
let mut errors = 0;
let mut successes = 0;
for (result, example) in results {
match result {
Err(err) => {
errors += 1;
println!("💥 {}{:?}", example.log_prefix, err);
}
Ok(judge_results) => {
for judge_result in judge_results {
match judge_result {
Ok(judge_output) => {
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
let score: u32 = judge_output.score;
let score_index = (score.min(5)) as usize;
Ok(judge_output) => {
successes += 1;
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
println!(
"{} {}{}",
SCORES[score_index], example.log_prefix, judge_output.score,
);
judge_scores.push(judge_output.score);
}
Err(err) => {
println!("💥 {}{:?}", example.log_prefix, err);
}
}
}
println!(
"{} {}{}",
SCORES[judge_output.score.min(5) as usize],
example.log_prefix,
judge_output.score,
);
judge_scores.push(judge_output.score);
}
}
println!(
"{} > {}",
" ".repeat(max_name_width),
example.output_file_path.display()
example.run_directory_path.display()
);
}
@@ -300,11 +258,12 @@ fn main() {
.map(|score| score as f32)
.sum::<f32>()
/ (score_count as f32);
println!("\nAverage score: {average_score}");
std::thread::sleep(std::time::Duration::from_secs(2));
app_state.client.telemetry().flush_events();
if errors > 0 {
println!("\n{errors} example(s) errored out. Average score among the {successes} example(s) that didn't error: {average_score}");
} else {
println!("\nAll {successes} examples ran successfully. Average score: {average_score}");
}
cx.update(|cx| cx.quit())
})
@@ -316,55 +275,12 @@ async fn run_example(
example: &Example,
model: Arc<dyn LanguageModel>,
app_state: Arc<AgentAppState>,
judge_repetitions: u32,
cx: &mut AsyncApp,
) -> Result<Vec<Result<JudgeOutput>>> {
let run_output = cx
.update(|cx| example.run(model.clone(), app_state.clone(), cx))?
) -> Result<JudgeOutput> {
cx.update(|cx| example.run(model.clone(), app_state, cx))?
.await?;
let diff = example.repository_diff().await?;
// Run judge for each repetition
let mut results = Vec::new();
for round in 0..judge_repetitions {
let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
if let Ok(judge_output) = &judge_result {
let cohort_id = example
.output_file_path
.parent()
.and_then(|p| p.file_name())
.map(|name| name.to_string_lossy().to_string())
.unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
let path = std::path::Path::new(".");
let commit_id = get_current_commit_id(path).await.unwrap_or_default();
telemetry::event!(
"Agent Eval Completed",
cohort_id = cohort_id,
example_name = example.name.clone(),
round = round,
score = judge_output.score,
analysis = judge_output.analysis,
tool_use_counts = run_output.tool_use_counts,
response_count = run_output.response_count,
token_usage = run_output.token_usage,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
repository_url = example.base.url.clone(),
repository_revision = example.base.revision.clone(),
diagnostics_summary = run_output.diagnostics,
commit_id = commit_id
);
}
results.push(judge_result);
}
app_state.client.telemetry().flush_events();
Ok(results)
example.judge(model, diff, cx).await
}
fn list_all_examples() -> Result<Vec<PathBuf>> {
@@ -526,13 +442,3 @@ pub fn authenticate_model_provider(
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}
pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
(run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
}
pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
futures::executor::block_on(async {
get_current_commit_id(repo_path).await.unwrap_or_default()
})
}

View File

@@ -8,11 +8,12 @@ use futures::channel::mpsc;
use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
use language::{Buffer, DiagnosticSeverity, OffsetRangeExt};
use language_model::{
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
StopReason, TokenUsage,
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role, StopReason, TokenUsage,
};
use project::lsp_store::LanguageServerState;
use project::{LspStore, Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::fmt::Write as _;
@@ -47,6 +48,17 @@ pub struct ExampleBase {
pub require_lsp: bool,
}
impl ExampleBase {
pub fn repo_name(&self) -> String {
self.url
.split('/')
.last()
.unwrap_or(&"")
.trim_end_matches(".git")
.into()
}
}
#[derive(Clone, Debug)]
pub struct Example {
pub name: String,
@@ -56,12 +68,8 @@ pub struct Example {
pub prompt: String,
/// Content of `criteria.md`
pub criteria: String,
/// Markdown output file to append to
pub output_file: Option<Arc<Mutex<File>>>,
/// Path to the output run directory.
pub run_dir: PathBuf,
/// Path to markdown output file
pub output_file_path: PathBuf,
/// Path to the directory containing the requests and responses for the agentic loop
pub run_directory_path: PathBuf,
/// Prefix used for logging that identifies this example
pub log_prefix: String,
}
@@ -94,27 +102,17 @@ impl Example {
let base_path = dir_path.join("base.toml");
let prompt_path = dir_path.join("prompt.md");
let criteria_path = dir_path.join("criteria.md");
let output_file_path = run_dir.join(format!("{}.md", name));
Ok(Example {
name: name.clone(),
base: toml::from_str(&fs::read_to_string(&base_path)?)?,
prompt: fs::read_to_string(prompt_path.clone())?,
criteria: fs::read_to_string(criteria_path.clone())?,
run_dir: run_dir.to_path_buf(),
output_file: None,
output_file_path,
run_directory_path: run_dir.to_path_buf(),
log_prefix: name,
})
}
pub fn set_repetition_number(&mut self, repetition_number: u32) {
if repetition_number > 0 {
self.name = format!("{}-{}", self.name, repetition_number);
self.output_file_path = self.run_dir.join(format!("{}.md", self.name));
}
}
pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
self.log_prefix = format!(
"{}{:<width$}\x1b[0m | ",
@@ -134,27 +132,20 @@ impl Example {
.context(format!("No such directory {WORKTREES_DIR}"))
.unwrap()
.join(&self.name)
.join(self.base.repo_name())
}
/// Set up the example by checking out the specified Git revision
pub async fn setup(&mut self) -> Result<()> {
let repo_path = repo_path_for_url(&self.base.url);
let revision_exists = run_git(&repo_path, &["rev-parse", "--verify", &self.base.revision])
.await
.is_ok();
println!("{}Fetching", self.log_prefix);
if !revision_exists {
println!(
"{}Fetching revision {}",
self.log_prefix, &self.base.revision
);
run_git(
&repo_path,
&["fetch", "--depth", "1", "origin", &self.base.revision],
)
.await?;
}
run_git(
&repo_path,
&["fetch", "--depth", "1", "origin", &self.base.revision],
)
.await?;
let worktree_path = self.worktree_path();
@@ -184,20 +175,9 @@ impl Example {
.await?;
}
// Create the output file
let output_file = Arc::new(Mutex::new(File::create(&self.output_file_path)?));
self.output_file = Some(output_file);
Ok(())
}
/// Returns the output file, panicking if it's not set
fn output_file(&self) -> Arc<Mutex<File>> {
self.output_file
.clone()
.expect("Output file not created. Call setup() first.")
}
pub fn run(
&self,
model: Arc<dyn LanguageModel>,
@@ -276,28 +256,26 @@ impl Example {
cx.background_executor().timer(Duration::new(5, 0)).await;
wait_for_lang_server(&lsp_store, this.log_prefix.clone(), cx).await?;
lsp_store.update(cx, |lsp_store, cx| {
lsp_open_handle.update(cx, |buffer, cx| {
buffer.update(cx, |buffer, cx| {
let has_language_server = lsp_store
.language_servers_for_local_buffer(buffer, cx)
.next()
.is_some();
if has_language_server {
Ok(())
} else {
Err(anyhow!(
"`{:?}` was opened to cause the language server to start, \
but no language servers are registered for its buffer. \
Set `require_lsp = false` in `base.toml` to skip this.",
language_file
))
}
})
})
})??;
// Retry up to 10 times, with a delay in between, for the language server to
// transition from the Starting to Running state.
const LS_START_ATTEMPTS: usize = 10;
const DELAY_BETWEEN_ATTEMPTS: Duration = Duration::new(1, 0);
let mut answer = None;
Some((lsp_open_handle, lsp_store))
for _ in 0..LS_START_ATTEMPTS {
if any_running(&language_file, lsp_store.clone(), lsp_open_handle.clone(), cx).await? {
answer = Some((lsp_open_handle, lsp_store));
break;
}
cx.background_executor().timer(DELAY_BETWEEN_ATTEMPTS).await;
}
if answer.is_none() {
return Err(anyhow!("Timed out waiting for language server to transition from Starting to Running state."));
}
answer
} else {
None
};
@@ -310,14 +288,18 @@ impl Example {
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
{
let output_file_ref = this.output_file();
let mut output_file = output_file_ref.lock().unwrap();
writeln!(&mut output_file, "👤 USER:").log_err();
writeln!(&mut output_file, "{}", this.prompt).log_err();
writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
output_file.flush().log_err();
}
thread.update(cx, |thread, _cx| {
let mut request_count = 0;
let run_dir_path = this.run_directory_path.clone();
thread.set_request_callback(move |request, response_events| {
request_count += 1;
let tools_file_path = run_dir_path.join(format!("{request_count}.tools.md"));
let messages_file_path = run_dir_path.join(format!("{request_count}.messages.md"));
let markdown = RequestMarkdown::new(request, response_events);
fs::write(tools_file_path, markdown.tools).expect("failed to write tools file");
fs::write(messages_file_path, markdown.messages).expect("failed to write messages file");
});
})?;
let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
Mutex::new(HashMap::default()).into();
@@ -330,7 +312,6 @@ impl Example {
let event_handler_task = cx.spawn({
// Need to clone the Arc here because the reference from output_file() won't live long enough
let output_file = this.output_file.clone().unwrap();
let log_prefix = this.log_prefix.clone();
let tool_use_counts = tool_use_counts.clone();
let thread = thread.downgrade();
@@ -346,8 +327,6 @@ impl Example {
return Err(anyhow!("ThreadEvent channel ended early"));
};
let mut output_file = output_file.lock().unwrap();
match event {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn) => {
@@ -368,44 +347,23 @@ impl Example {
ThreadEvent::ShowError(thread_error) => {
break Err(anyhow!(thread_error.clone()));
}
ThreadEvent::StreamedAssistantText(_, chunk) => {
write!(&mut output_file, "{}", chunk).log_err();
}
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
write!(&mut output_file, "{}", chunk).log_err();
}
ThreadEvent::UsePendingTools { tool_uses } => {
writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
for tool_use in tool_uses {
writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
.log_err();
}
ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(tool_use) = pending_tool_use {
let message = format!("TOOL FINISHED: {}", tool_use.name);
println!("{}{message}", log_prefix);
}
thread.update(cx, |thread, _cx| {
if let Some(tool_use) = pending_tool_use {
if let Some(tool_result) = thread.tool_result(&tool_use_id) {
let message = if tool_result.is_error {
format!("TOOL FAILED: {}", tool_use.name)
} else {
format!("TOOL FINISHED: {}", tool_use.name)
};
println!("{log_prefix}{message}");
writeln!(&mut output_file, "\n{}", message).log_err();
writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
let mut tool_use_counts = tool_use_counts.lock().unwrap();
*tool_use_counts
.entry(tool_result.tool_name.clone())
.or_insert(0) += 1;
} else {
let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
println!("{log_prefix}{message}");
writeln!(&mut output_file, "\n{}", message).log_err();
}
if let Some(tool_result) = thread.tool_result(&tool_use_id) {
let mut tool_use_counts = tool_use_counts.lock().unwrap();
*tool_use_counts
.entry(tool_result.tool_name.clone())
.or_insert(0) += 1;
}
})?;
}
@@ -424,8 +382,6 @@ impl Example {
}
}
}
output_file.flush().log_err();
}
}
});
@@ -447,10 +403,6 @@ impl Example {
println!("{}Getting repository diff", this.log_prefix);
let repository_diff = this.repository_diff().await?;
let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name));
let mut repository_diff_output_file = File::create(&repository_diff_path)?;
writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();
println!("{}Getting diagnostics", this.log_prefix);
let diagnostics = cx
.update(move |cx| {
@@ -482,9 +434,18 @@ impl Example {
&self,
model: Arc<dyn LanguageModel>,
repository_diff: String,
judge_repetitions: u32,
cx: &AsyncApp,
) -> Result<JudgeOutput> {
let mut output_file = File::create(self.run_directory_path.join("judge.md"))
.expect("failed to create judge.md");
{
writeln!(&mut output_file, "\n\n").log_err();
writeln!(&mut output_file, "========================================").log_err();
writeln!(&mut output_file, " REPOSITORY DIFF ").log_err();
writeln!(&mut output_file, "========================================").log_err();
writeln!(&mut output_file, "\n{}", &repository_diff).log_err();
}
let judge_prompt = include_str!("judge_prompt.hbs");
let judge_prompt_name = "judge_prompt";
let mut handlebars = Handlebars::new();
@@ -510,14 +471,11 @@ impl Example {
let response = send_language_model_request(model, request, cx).await?;
let judge_file_path = self.run_dir.join(format!(
"{}_judge_{}.md",
self.name, // This is the eval_name
judge_repetitions
));
let mut judge_output_file = File::create(&judge_file_path)?;
writeln!(&mut judge_output_file, "{}", &response).log_err();
writeln!(&mut output_file, "\n\n").log_err();
writeln!(&mut output_file, "========================================").log_err();
writeln!(&mut output_file, " JUDGE OUTPUT ").log_err();
writeln!(&mut output_file, "========================================").log_err();
writeln!(&mut output_file, "\n{}", &response).log_err();
parse_judge_output(&response)
}
@@ -593,6 +551,55 @@ fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool
.any(|(_, status)| !status.pending_work.is_empty())
}
async fn any_running(
language_file: &ProjectPath,
lsp_store: Entity<LspStore>,
lsp_open_handle: Entity<Entity<Buffer>>,
cx: &mut AsyncApp,
) -> Result<bool> {
lsp_store.update(cx, |lsp_store, cx| {
lsp_open_handle.update(cx, |buffer, cx| {
buffer.update(cx, |buffer, cx| {
match lsp_store.language_server_state_for_local_buffer(buffer, cx) {
Some(states) => {
let mut any_starting = false;
for state in states {
match state {
LanguageServerState::Starting { .. } => {
// A server in the "starting" state means we should keep waiting for
// it to advance to the "running" state.
any_starting = true;
},
LanguageServerState::Running { .. } => {
// We found one that's running, so we're done.
return Ok(true);
}
}
}
if any_starting {
Ok(false)
} else {
Err(anyhow!(
"`{language_file:?}` was opened to cause the language server to start, \
but no language servers are registered for its buffer. \
Set `require_lsp = false` in `base.toml` to skip using a language server for this file.",
))
}
}
None => {
Err(anyhow!(
"`{language_file:?}` was opened locally to cause the language server to start, \
but the language server's mode was not set to LspStoreMode::Local."
))
}
}
})
})
})?
}
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
let paths_with_diagnostics = project.update(cx, |project, cx| {
project
@@ -718,6 +725,129 @@ pub async fn send_language_model_request(
}
}
struct RequestMarkdown {
tools: String,
messages: String,
}
impl RequestMarkdown {
fn new(
request: &LanguageModelRequest,
response_events: &[Result<LanguageModelCompletionEvent, String>],
) -> Self {
let mut tools = String::new();
let mut messages = String::new();
// Print the tools
if !request.tools.is_empty() {
for tool in &request.tools {
write!(&mut tools, "# {}\n\n", tool.name).unwrap();
write!(&mut tools, "{}\n\n", tool.description).unwrap();
write!(
&mut tools,
"```json\n{}\n```\n\n",
serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default()
)
.unwrap();
}
}
// Print the messages
for message in &request.messages {
let role_str = match message.role {
Role::User => "👤 USER",
Role::Assistant => "🤖 ASSISTANT",
Role::System => "⚙️ SYSTEM",
};
messages.push_str(&format!("# {}\n\n", role_str));
for content in &message.content {
match content {
MessageContent::Text(text) => {
messages.push_str(text);
messages.push_str("\n\n");
}
MessageContent::Image(_) => {
messages.push_str("[IMAGE DATA]\n\n");
}
MessageContent::ToolUse(tool_use) => {
messages.push_str(&format!(
"**Tool Use**: {} (ID: {})\n",
tool_use.name, tool_use.id
));
messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
}
MessageContent::ToolResult(tool_result) => {
messages.push_str(&format!(
"**Tool Result**: {} (ID: {})\n",
tool_result.tool_name, tool_result.tool_use_id
));
if tool_result.is_error {
messages.push_str("**ERROR:**\n");
}
messages.push_str(&format!("```\n{}\n```\n\n", tool_result.content));
}
}
}
}
// Print the response events if any
if !response_events.is_empty() {
messages.push_str("# Response\n\n");
let mut text_buffer = String::new();
let mut thinking_buffer = String::new();
let flush_buffers =
|output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
if !text_buffer.is_empty() {
output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
text_buffer.clear();
}
if !thinking_buffer.is_empty() {
output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
thinking_buffer.clear();
}
};
for event in response_events {
match event {
Ok(LanguageModelCompletionEvent::Text(text)) => {
text_buffer.push_str(text);
}
Ok(LanguageModelCompletionEvent::Thinking(text)) => {
thinking_buffer.push_str(text);
}
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
flush_buffers(&mut messages, &mut text_buffer, &mut thinking_buffer);
messages.push_str(&format!("**Stop**: {:?}\n\n", reason));
}
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
flush_buffers(&mut messages, &mut text_buffer, &mut thinking_buffer);
messages.push_str(&format!(
"**Tool Use**: {} (ID: {})\n",
tool_use.name, tool_use.id
));
messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
}
Ok(
LanguageModelCompletionEvent::UsageUpdate(_)
| LanguageModelCompletionEvent::StartMessage { .. },
) => {}
Err(error) => {
flush_buffers(&mut messages, &mut text_buffer, &mut thinking_buffer);
messages.push_str(&format!("**Error**: {}\n\n", error));
}
}
}
flush_buffers(&mut messages, &mut text_buffer, &mut thinking_buffer);
}
Self { tools, messages }
}
}
#[cfg(test)]
mod test {
use super::*;

View File

@@ -1,28 +0,0 @@
use anyhow::Result;
use std::fs;
use std::path::{Path, PathBuf};
use uuid::Uuid;
pub fn get_or_create_id(path: &Path) -> Result<String> {
if let Ok(id) = fs::read_to_string(path) {
let trimmed = id.trim();
if !trimmed.is_empty() {
return Ok(trimmed.to_string());
}
}
let new_id = Uuid::new_v4().to_string();
fs::write(path, &new_id)?;
Ok(new_id)
}
pub fn eval_system_id_path() -> PathBuf {
dirs::data_local_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("zed-eval-system-id")
}
pub fn eval_installation_id_path() -> PathBuf {
dirs::data_local_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("zed-eval-installation-id")
}

View File

@@ -84,11 +84,6 @@ impl FeatureFlag for ZedPro {
const NAME: &'static str = "zed-pro";
}
pub struct ZedProWebSearchTool {}
impl FeatureFlag for ZedProWebSearchTool {
const NAME: &'static str = "zed-pro-web-search-tool";
}
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {

View File

@@ -2,7 +2,6 @@ use gpui::{App, ClipboardItem, PromptLevel, actions};
use system_specs::SystemSpecs;
use util::ResultExt;
use workspace::Workspace;
use zed_actions::feedback::FileBugReport;
pub mod feedback_modal;
@@ -13,6 +12,7 @@ actions!(
[
CopySystemSpecsIntoClipboard,
EmailZed,
FileBugReport,
OpenZedRepo,
RequestFeature,
]
@@ -27,7 +27,7 @@ fn file_bug_report_url(specs: &SystemSpecs) -> String {
concat!(
"https://github.com/zed-industries/zed/issues/new",
"?",
"template=10_bug_report.yml",
"template=1_bug_report.yml",
"&",
"environment={}"
),

View File

@@ -1333,23 +1333,13 @@ impl FakeFs {
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
anyhow::bail!("pointed-to git dir {path:?} not found")
};
let FakeFsEntry::Dir {
git_repo_state,
entries,
..
} = &mut *git_dir_entry.lock()
else {
let FakeFsEntry::Dir { git_repo_state, .. } = &mut *git_dir_entry.lock() else {
anyhow::bail!("gitfile points to a non-directory")
};
let common_dir = if let Some(child) = entries.get("commondir") {
Path::new(
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
.context("commondir content")?,
)
.to_owned()
} else {
canonical_path.clone()
};
let common_dir = canonical_path
.ancestors()
.find(|ancestor| ancestor.ends_with(".git"))
.ok_or_else(|| anyhow!("repository dir not contained in any .git"))?;
let repo_state = git_repo_state.get_or_insert_with(|| {
Arc::new(Mutex::new(FakeGitRepositoryState::new(
state.git_event_tx.clone(),
@@ -1357,7 +1347,7 @@ impl FakeFs {
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, &canonical_path, &common_dir);
let result = f(&mut repo_state, &canonical_path, common_dir);
if emit_git_event {
state.emit_event([(canonical_path, None)]);

View File

@@ -1013,6 +1013,7 @@ impl GitRepository for RealGitRepository {
let mut command = new_smol_command("git");
command
.envs(env.iter())
.env("GIT_HTTP_USER_AGENT", "Zed")
.current_dir(&working_directory)
.args(["push"])
.args(options.map(|option| match option {
@@ -1044,6 +1045,7 @@ impl GitRepository for RealGitRepository {
let mut command = new_smol_command("git");
command
.envs(env.iter())
.env("GIT_HTTP_USER_AGENT", "Zed")
.current_dir(&working_directory?)
.args(["pull"])
.arg(remote_name)
@@ -1068,6 +1070,7 @@ impl GitRepository for RealGitRepository {
let mut command = new_smol_command("git");
command
.envs(env.iter())
.env("GIT_HTTP_USER_AGENT", "Zed")
.current_dir(&working_directory?)
.args(["fetch", "--all"])
.stdout(smol::process::Stdio::piped())

View File

@@ -599,11 +599,33 @@ impl GitPanel {
}
pub fn entry_by_path(&self, path: &RepoPath) -> Option<usize> {
fn binary_search<F>(mut low: usize, mut high: usize, is_target: F) -> Option<usize>
where
F: Fn(usize) -> std::cmp::Ordering,
{
while low < high {
let mid = low + (high - low) / 2;
match is_target(mid) {
std::cmp::Ordering::Equal => return Some(mid),
std::cmp::Ordering::Less => low = mid + 1,
std::cmp::Ordering::Greater => high = mid,
}
}
None
}
if self.conflicted_count > 0 {
let conflicted_start = 1;
if let Ok(ix) = self.entries[conflicted_start..conflicted_start + self.conflicted_count]
.binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path))
{
if let Some(ix) = binary_search(
conflicted_start,
conflicted_start + self.conflicted_count,
|ix| {
self.entries[ix]
.status_entry()
.unwrap()
.repo_path
.cmp(&path)
},
) {
return Some(ix);
}
}
@@ -613,8 +635,14 @@ impl GitPanel {
} else {
0
} + 1;
if let Ok(ix) = self.entries[tracked_start..tracked_start + self.tracked_count]
.binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path))
if let Some(ix) =
binary_search(tracked_start, tracked_start + self.tracked_count, |ix| {
self.entries[ix]
.status_entry()
.unwrap()
.repo_path
.cmp(&path)
})
{
return Some(ix);
}
@@ -629,8 +657,14 @@ impl GitPanel {
} else {
0
} + 1;
if let Ok(ix) = self.entries[untracked_start..untracked_start + self.new_count]
.binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path))
if let Some(ix) =
binary_search(untracked_start, untracked_start + self.new_count, |ix| {
self.entries[ix]
.status_entry()
.unwrap()
.repo_path
.cmp(&path)
})
{
return Some(ix);
}
@@ -3577,15 +3611,6 @@ impl GitPanel {
items
}
})
.when(
!self.horizontal_scrollbar.show_track
&& self.horizontal_scrollbar.show_scrollbar,
|this| {
// when not showing the horizontal scrollbar track, make sure we don't
// obscure the last entry
this.pb(scroll_track_size)
},
)
.size_full()
.flex_grow()
.with_sizing_behavior(ListSizingBehavior::Auto)

View File

@@ -125,7 +125,7 @@ pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "String::is_empty")]
pub model: String,
pub contents: Vec<Content>,
pub system_instruction: Option<SystemInstruction>,
pub system_instructions: Option<SystemInstructions>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -162,7 +162,7 @@ pub struct Content {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
pub struct SystemInstructions {
pub parts: Vec<Part>,
}

View File

@@ -589,6 +589,11 @@ impl<V> Entity<V> {
use postage::prelude::{Sink as _, Stream as _};
let (tx, mut rx) = postage::mpsc::channel(1024);
let timeout_duration = if cfg!(target_os = "macos") {
Duration::from_millis(100)
} else {
Duration::from_secs(1)
};
let mut cx = cx.app.borrow_mut();
let subscriptions = (
@@ -610,7 +615,7 @@ impl<V> Entity<V> {
let handle = self.downgrade();
async move {
crate::util::timeout(Duration::from_secs(1), async move {
crate::util::timeout(timeout_duration, async move {
loop {
{
let cx = cx.borrow();

View File

@@ -27,8 +27,6 @@ use objc::{
};
use std::{cell::RefCell, ffi::c_void, mem, ptr, rc::Rc};
use super::NSStringExt;
#[derive(Clone)]
pub struct MacScreenCaptureSource {
sc_display: id,
@@ -186,10 +184,7 @@ pub(crate) fn get_sources() -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptur
Ok(result)
} else {
let msg: id = msg_send![error, localizedDescription];
Err(anyhow!(
"Screen share failed: {:?}",
NSStringExt::to_str(&msg)
))
Err(anyhow!("Failed to register: {:?}", msg))
};
tx.send(result).ok();
});

View File

@@ -97,10 +97,7 @@ pub struct TokenUsage {
impl TokenUsage {
pub fn total_tokens(&self) -> u32 {
self.input_tokens
+ self.output_tokens
+ self.cache_read_input_tokens
+ self.cache_creation_input_tokens
self.input_tokens + self.output_tokens
}
}

View File

@@ -142,24 +142,6 @@ impl fmt::Display for MaxMonthlySpendReachedError {
}
}
#[derive(Error, Debug)]
pub struct ModelRequestLimitReachedError {
pub plan: Plan,
}
impl fmt::Display for ModelRequestLimitReachedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let message = match self.plan {
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
Plan::ZedPro => {
"Model request limit reached. Upgrade to usage-based billing for more requests."
}
};
write!(f, "{message}")
}
}
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);

View File

@@ -53,7 +53,6 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -705,12 +705,12 @@ pub fn map_to_language_model_completion_events(
update_usage(&mut state.usage, &message.usage);
return Some((
vec![
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
&state.usage,
))),
Ok(LanguageModelCompletionEvent::StartMessage {
message_id: message.id,
}),
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
&state.usage,
))),
],
state,
));

View File

@@ -16,21 +16,18 @@ use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::str::FromStr as _;
use std::{
sync::{Arc, LazyLock},
time::Duration,
@@ -38,7 +35,6 @@ use std::{
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{CURRENT_PLAN_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME};
use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
@@ -555,32 +551,6 @@ impl CloudLanguageModel {
.is_some()
{
return Err(anyhow!(MaxMonthlySpendReachedError));
} else if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
.is_some()
{
if let Some("model_requests") = response
.headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
.and_then(|resource| resource.to_str().ok())
{
if let Some(plan) = response
.headers()
.get(CURRENT_PLAN_HEADER_NAME)
.and_then(|plan| plan.to_str().ok())
.and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
{
let plan = match plan {
zed_llm_client::Plan::Free => Plan::Free,
zed_llm_client::Plan::ZedPro => Plan::ZedPro,
};
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
}
return Err(anyhow!("Forbidden"));
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
// If we encounter an error in the 500 range, retry after a delay.
// We've seen at least these in the wild from API providers:

View File

@@ -4,7 +4,7 @@ use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
FunctionDeclaration, GenerateContentResponse, Part, SystemInstructions, UsageMetadata,
};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -405,7 +405,7 @@ pub fn into_google(
.map_or(false, |msg| matches!(msg.role, Role::System))
{
let message = request.messages.remove(0);
Some(SystemInstruction {
Some(SystemInstructions {
parts: map_content(message.content),
})
} else {
@@ -414,7 +414,7 @@ pub fn into_google(
google_ai::GenerateContentRequest {
model,
system_instruction: system_instructions,
system_instructions,
contents: request
.messages
.into_iter()

View File

@@ -61,17 +61,14 @@ impl Anchor {
return Ordering::Equal;
}
let self_excerpt_id = snapshot.latest_excerpt_id(self.excerpt_id);
let other_excerpt_id = snapshot.latest_excerpt_id(other.excerpt_id);
let excerpt_id_cmp = self_excerpt_id.cmp(&other_excerpt_id, snapshot);
let excerpt_id_cmp = self.excerpt_id.cmp(&other.excerpt_id, snapshot);
if excerpt_id_cmp.is_ne() {
return excerpt_id_cmp;
}
if self_excerpt_id == ExcerptId::min() || self_excerpt_id == ExcerptId::max() {
if self.excerpt_id == ExcerptId::min() || self.excerpt_id == ExcerptId::max() {
return Ordering::Equal;
}
if let Some(excerpt) = snapshot.excerpt(self_excerpt_id) {
if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) {
let text_cmp = self.text_anchor.cmp(&other.text_anchor, &excerpt.buffer);
if text_cmp.is_ne() {
return text_cmp;

View File

@@ -5170,7 +5170,6 @@ impl MultiBufferSnapshot {
excerpt_id: ExcerptId,
text_anchor: text::Anchor,
) -> Option<Anchor> {
let excerpt_id = self.latest_excerpt_id(excerpt_id);
let locator = self.excerpt_locator_for_id(excerpt_id);
let mut cursor = self.excerpts.cursor::<Option<&Locator>>(&());
cursor.seek(locator, Bias::Left, &());
@@ -6042,7 +6041,7 @@ impl MultiBufferSnapshot {
return &entry.locator;
}
}
panic!("invalid excerpt id {id:?}")
panic!("invalid excerpt id {:?}", id)
}
}

View File

@@ -746,20 +746,19 @@ fn test_expand_excerpts(cx: &mut App) {
drop(snapshot);
multibuffer.update(cx, |multibuffer, cx| {
let line_zero = multibuffer.snapshot(cx).anchor_before(Point::new(0, 0));
multibuffer.expand_excerpts(
multibuffer.excerpt_ids(),
1,
ExpandExcerptDirection::UpAndDown,
cx,
);
let snapshot = multibuffer.snapshot(cx);
let line_two = snapshot.anchor_before(Point::new(2, 0));
assert_eq!(line_two.cmp(&line_zero, &snapshot), cmp::Ordering::Greater);
)
});
let snapshot = multibuffer.read(cx).snapshot(cx);
// Expanding context lines causes the line containing 'fff' to appear in two different excerpts.
// We don't attempt to merge them, because removing the excerpt could create inconsistency with other layers
// that are tracking excerpt ids.
assert_eq!(
snapshot.text(),
concat!(

View File

@@ -21,10 +21,7 @@ use futures::{
channel::{mpsc, oneshot},
future::{Shared, join_all},
};
use gpui::{
App, AppContext, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, SharedString,
Task,
};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
use http_client::HttpClient;
use language::{BinaryStatus, LanguageRegistry, LanguageToolchainStore};
use lsp::LanguageServerName;
@@ -93,17 +90,6 @@ impl LocalDapStore {
fn next_session_id(&self) -> SessionId {
SessionId(self.next_session_id.fetch_add(1, SeqCst))
}
pub(crate) fn locate_binary(
&self,
mut definition: DebugTaskDefinition,
executor: BackgroundExecutor,
) -> Task<DebugTaskDefinition> {
let locator_store = self.locator_store.clone();
executor.spawn(async move {
let _ = locator_store.resolve_debug_config(&mut definition).await;
definition
})
}
}
pub struct RemoteDapStore {
@@ -349,7 +335,7 @@ impl DapStore {
pub fn new_session(
&mut self,
binary: DebugAdapterBinary,
config: DebugTaskDefinition,
mut config: DebugTaskDefinition,
parent_session: Option<Entity<Session>>,
cx: &mut Context<Self>,
) -> (SessionId, Task<Result<Entity<Session>>>) {
@@ -366,10 +352,22 @@ impl DapStore {
}
let (initialized_tx, initialized_rx) = oneshot::channel();
let locator_store = local_store.locator_store.clone();
let start_debugging_tx = local_store.start_debugging_tx.clone();
let task = cx.spawn(async move |this, cx| {
if config.locator.is_some() {
config = cx
.background_spawn(async move {
locator_store
.resolve_debug_config(&mut config)
.await
.map(|_| config)
})
.await?;
}
let start_client_task = this.update(cx, |this, cx| {
Session::local(
this.breakpoint_store.clone(),

View File

@@ -1,12 +1,13 @@
use super::DapLocator;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use serde_json::Value;
use serde_json::{Value, json};
use smol::{
io::AsyncReadExt,
process::{Command, Stdio},
};
use task::DebugTaskDefinition;
use util::maybe;
pub(super) struct CargoLocator;
@@ -108,13 +109,43 @@ impl DapLocator for CargoLocator {
None
}
};
let Some(executable) = executable.or_else(|| executables.first().cloned()) else {
return Err(anyhow!("Couldn't get executable in cargo locator"));
};
launch_config.program = executable;
if debug_config.adapter == "LLDB" && debug_config.initialize_args.is_none() {
// Find Rust pretty-printers in current toolchain's sysroot
let cwd = launch_config.cwd.clone();
debug_config.initialize_args = maybe!(async move {
let cwd = cwd?;
let output = Command::new("rustc")
.arg("--print")
.arg("sysroot")
.current_dir(cwd)
.output()
.await
.ok()?;
if !output.status.success() {
return None;
}
let sysroot_path = String::from_utf8(output.stdout).ok()?;
let sysroot_path = sysroot_path.trim_end();
let first_command = format!(
r#"command script import "{sysroot_path}/lib/rustlib/etc/lldb_lookup.py"#
);
let second_command =
format!(r#"command source -s 0 '{sysroot_path}/lib/rustlib/etc/lldb_commands"#);
Some(json!({"initCommands": [first_command, second_command]}))
})
.await;
}
launch_config.args.clear();
if let Some(test_name) = test_name {
launch_config.args.push(test_name);

View File

@@ -6233,6 +6233,21 @@ impl LspStore {
})
}
pub fn language_server_state_for_local_buffer<'a>(
&'a self,
buffer: &Buffer,
cx: &mut App,
) -> Option<impl Iterator<Item = &'a LanguageServerState>> {
let local = self.as_local()?;
Some(
local
.language_server_ids_for_buffer(buffer, cx)
.into_iter()
.filter_map(move |server_id| local.language_servers.get(&server_id)),
)
}
pub fn language_servers_for_local_buffer<'a>(
&'a self,
buffer: &Buffer,

View File

@@ -1482,18 +1482,6 @@ impl Project {
.update(cx, |dap_store, cx| dap_store.delegate(&worktree, cx))
})?;
let task = this.update(cx, |project, cx| {
project.dap_store.read(cx).as_local().and_then(|local| {
config.locator.is_some().then(|| {
local.locate_binary(config.clone(), cx.background_executor().clone())
})
})
})?;
let config = if let Some(task) = task {
task.await
} else {
config
};
let binary = adapter
.get_binary(&delegate, &config, user_installed_path, cx)
.await?;

View File

@@ -8273,34 +8273,17 @@ async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) {
json!({
".git": {
"worktrees": {
"some-worktree": {
"commondir": "../..\n"
}
"some-worktree": {}
},
"modules": {
"subdir": {
"some-submodule": {
// For is_git_dir
"HEAD": "",
"config": "",
}
}
}
},
"src": {
"a.txt": "A",
},
"some-worktree": {
".git": "gitdir: ../.git/worktrees/some-worktree\n",
".git": "gitdir: ../.git/worktrees/some-worktree",
"src": {
"b.txt": "B",
}
},
"subdir": {
"some-submodule": {
".git": "gitdir: ../../.git/modules/subdir/some-submodule\n",
"c.txt": "C",
}
}
}),
)
@@ -8332,11 +8315,9 @@ async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) {
[
Path::new(path!("/project")).into(),
Path::new(path!("/project/some-worktree")).into(),
Path::new(path!("/project/subdir/some-submodule")).into(),
]
);
// Generate a git-related event for the worktree and check that it's refreshed.
fs.with_git_state(
path!("/project/some-worktree/.git").as_ref(),
true,
@@ -8378,45 +8359,6 @@ async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) {
StatusCode::Modified.worktree(),
);
});
// The same for the submodule.
fs.with_git_state(
path!("/project/subdir/some-submodule/.git").as_ref(),
true,
|state| {
state.head_contents.insert("c.txt".into(), "c".to_owned());
state.index_contents.insert("c.txt".into(), "c".to_owned());
},
)
.unwrap();
cx.run_until_parked();
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/project/subdir/some-submodule/c.txt"), cx)
})
.await
.unwrap();
let (submodule_repo, barrier) = project.update(cx, |project, cx| {
let (repo, _) = project
.git_store()
.read(cx)
.repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx)
.unwrap();
pretty_assertions::assert_eq!(
repo.read(cx).work_directory_abs_path,
Path::new(path!("/project/subdir/some-submodule")).into(),
);
let barrier = repo.update(cx, |repo, _| repo.barrier());
(repo.clone(), barrier)
});
barrier.await.unwrap();
submodule_repo.update(cx, |repo, _| {
pretty_assertions::assert_eq!(
repo.status_for_path(&"c.txt".into()).unwrap().status,
StatusCode::Modified.worktree(),
);
});
}
#[gpui::test]

View File

@@ -36,7 +36,7 @@ use ui::{
IconWithIndicator, Indicator, PopoverMenu, Tooltip, h_flex, prelude::*,
};
use util::ResultExt;
use workspace::{Workspace, notifications::NotifyResultExt};
use workspace::{BottomDockLayout, Workspace, notifications::NotifyResultExt};
use zed_actions::{OpenBrowser, OpenRecent, OpenRemote};
pub use onboarding_banner::restore_banner;
@@ -210,6 +210,7 @@ impl Render for TitleBar {
.pr_1()
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
.children(self.render_call_controls(window, cx))
.child(self.render_bottom_dock_layout_menu(cx))
.map(|el| {
let status = self.client.status();
let status = &*status.borrow();
@@ -622,6 +623,101 @@ impl TitleBar {
}
}
pub fn render_bottom_dock_layout_menu(&self, cx: &mut Context<Self>) -> impl IntoElement {
let workspace = self.workspace.upgrade().unwrap();
let current_layout = workspace.update(cx, |workspace, _cx| workspace.bottom_dock_layout());
PopoverMenu::new("layout-menu")
.trigger(
IconButton::new("toggle_layout", IconName::Layout)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text("Toggle Layout Menu")),
)
.anchor(gpui::Corner::TopRight)
.menu(move |window, cx| {
ContextMenu::build(window, cx, {
let workspace = workspace.clone();
move |menu, _, _| {
menu.label("Bottom Dock")
.separator()
.toggleable_entry(
"Contained",
current_layout == BottomDockLayout::Contained,
ui::IconPosition::End,
None,
{
let workspace = workspace.clone();
move |window, cx| {
workspace.update(cx, |workspace, cx| {
workspace.set_bottom_dock_layout(
BottomDockLayout::Contained,
window,
cx,
);
});
}
},
)
.toggleable_entry(
"Full",
current_layout == BottomDockLayout::Full,
ui::IconPosition::End,
None,
{
let workspace = workspace.clone();
move |window, cx| {
workspace.update(cx, |workspace, cx| {
workspace.set_bottom_dock_layout(
BottomDockLayout::Full,
window,
cx,
);
});
}
},
)
.toggleable_entry(
"Left Aligned",
current_layout == BottomDockLayout::LeftAligned,
ui::IconPosition::End,
None,
{
let workspace = workspace.clone();
move |window, cx| {
workspace.update(cx, |workspace, cx| {
workspace.set_bottom_dock_layout(
BottomDockLayout::LeftAligned,
window,
cx,
);
});
}
},
)
.toggleable_entry(
"Right Aligned",
current_layout == BottomDockLayout::RightAligned,
ui::IconPosition::End,
None,
{
let workspace = workspace.clone();
move |window, cx| {
workspace.update(cx, |workspace, cx| {
workspace.set_bottom_dock_layout(
BottomDockLayout::RightAligned,
window,
cx,
);
});
}
},
)
}
})
.into()
})
}
pub fn render_sign_in_button(&mut self, _: &mut Context<Self>) -> Button {
let client = self.client.clone();
Button::new("sign_in", "Sign in")

View File

@@ -73,11 +73,11 @@ impl Tab {
self
}
pub fn content_height(cx: &App) -> Pixels {
pub fn content_height(cx: &mut App) -> Pixels {
DynamicSpacing::Base32.px(cx) - px(1.)
}
pub fn container_height(cx: &App) -> Pixels {
pub fn container_height(cx: &mut App) -> Pixels {
DynamicSpacing::Base32.px(cx)
}
}

View File

@@ -160,11 +160,7 @@ impl Render for Tooltip {
}),
)
.when_some(self.meta.clone(), |this, meta| {
this.child(
div()
.max_w_72()
.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
)
this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
})
})
}

View File

@@ -1,20 +0,0 @@
[package]
name = "web_search"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/web_search.rs"
[dependencies]
anyhow.workspace = true
collections.workspace = true
gpui.workspace = true
serde.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

View File

@@ -1,64 +0,0 @@
use anyhow::Result;
use collections::HashMap;
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
use std::sync::Arc;
use zed_llm_client::WebSearchResponse;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| WebSearchRegistry::default());
cx.set_global(GlobalWebSearchRegistry(registry));
}
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct WebSearchProviderId(pub SharedString);
pub trait WebSearchProvider {
fn id(&self) -> WebSearchProviderId;
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
}
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
impl Global for GlobalWebSearchRegistry {}
#[derive(Default)]
pub struct WebSearchRegistry {
providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
active_provider: Option<Arc<dyn WebSearchProvider>>,
}
impl WebSearchRegistry {
pub fn global(cx: &App) -> Entity<Self> {
cx.global::<GlobalWebSearchRegistry>().0.clone()
}
pub fn read_global(cx: &App) -> &Self {
cx.global::<GlobalWebSearchRegistry>().0.read(cx)
}
pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
self.providers.values()
}
pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
self.active_provider.clone()
}
pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
self.active_provider = Some(provider.clone());
self.providers.insert(provider.id(), provider);
}
pub fn register_provider<T: WebSearchProvider + 'static>(
&mut self,
provider: T,
_cx: &mut Context<Self>,
) {
let id = provider.id();
let provider = Arc::new(provider);
self.providers.insert(id.clone(), provider.clone());
if self.active_provider.is_none() {
self.active_provider = Some(provider);
}
}
}

View File

@@ -1,26 +0,0 @@
[package]
name = "web_search_providers"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/web_search_providers.rs"
[dependencies]
anyhow.workspace = true
client.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
language_model.workspace = true
serde.workspace = true
serde_json.workspace = true
web_search.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

View File

@@ -1,103 +0,0 @@
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use client::Client;
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{WebSearchBody, WebSearchResponse};
pub struct CloudWebSearchProvider {
state: Entity<State>,
}
impl CloudWebSearchProvider {
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
let state = cx.new(|cx| State::new(client, cx));
Self { state }
}
}
pub struct State {
client: Arc<Client>,
llm_api_token: LlmApiToken,
_llm_token_subscription: Subscription,
}
impl State {
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
client,
llm_api_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
|this, _, _event, cx| {
let client = this.client.clone();
let llm_api_token = this.llm_api_token.clone();
cx.spawn(async move |_this, _cx| {
llm_api_token.refresh(&client).await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
},
),
}
}
}
impl WebSearchProvider for CloudWebSearchProvider {
fn id(&self) -> WebSearchProviderId {
WebSearchProviderId("zed.dev".into())
}
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
let state = self.state.read(cx);
let client = state.client.clone();
let llm_api_token = state.llm_api_token.clone();
let body = WebSearchBody { query };
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
}
}
async fn perform_web_search(
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
let request_builder = http_client::Request::builder().method(Method::POST);
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
request_builder.uri(web_search_url)
} else {
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
};
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client
.send(request)
.await
.context("failed to send web search request")?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"error performing web search.\nStatus: {:?}\nBody: {body}",
response.status(),
));
}
}

View File

@@ -1,35 +0,0 @@
mod cloud;
use client::Client;
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
use gpui::{App, Context};
use std::sync::Arc;
use web_search::WebSearchRegistry;
pub fn init(client: Arc<Client>, cx: &mut App) {
let registry = WebSearchRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_web_search_providers(registry, client, cx);
});
}
fn register_web_search_providers(
_registry: &mut WebSearchRegistry,
client: Arc<Client>,
cx: &mut Context<WebSearchRegistry>,
) {
cx.observe_flag::<ZedProWebSearchTool, _>({
let client = client.clone();
move |is_enabled, cx| {
if is_enabled {
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
registry.register_provider(
cloud::CloudWebSearchProvider::new(client.clone(), cx),
cx,
);
});
}
}
})
.detach();
}

View File

@@ -106,7 +106,6 @@ use uuid::Uuid;
pub use workspace_settings::{
AutosaveSetting, BottomDockLayout, RestoreOnStartupBehavior, TabBarSettings, WorkspaceSettings,
};
use zed_actions::feedback::FileBugReport;
use crate::notifications::NotificationId;
use crate::persistence::{
@@ -5396,6 +5395,8 @@ enum ActivateInDirectionTarget {
}
fn notify_if_database_failed(workspace: WindowHandle<Workspace>, cx: &mut AsyncApp) {
const REPORT_ISSUE_URL: &str = "https://github.com/zed-industries/zed/issues/new?assignees=&labels=admin+read%2Ctriage%2Cbug&projects=&template=1_bug_report.yml";
workspace
.update(cx, |workspace, _, cx| {
if (*db::ALL_FILE_DB_FAILED).load(std::sync::atomic::Ordering::Acquire) {
@@ -5409,9 +5410,7 @@ fn notify_if_database_failed(workspace: WindowHandle<Workspace>, cx: &mut AsyncA
MessageNotification::new("Failed to load the database file.", cx)
.primary_message("File an Issue")
.primary_icon(IconName::Plus)
.primary_on_click(|window, cx| {
window.dispatch_action(Box::new(FileBugReport), cx)
})
.primary_on_click(|_window, cx| cx.open_url(REPORT_ISSUE_URL))
})
},
);

View File

@@ -3120,12 +3120,32 @@ impl BackgroundScannerState {
.as_path()
.into();
let (repository_dir_abs_path, common_dir_abs_path) =
discover_git_paths(&dot_git_abs_path, fs);
let mut common_dir_abs_path = dot_git_abs_path.clone();
let mut repository_dir_abs_path = dot_git_abs_path.clone();
// Parse .git if it's a "gitfile" pointing to a repository directory elsewhere.
if let Some(dot_git_contents) = smol::block_on(fs.load(&dot_git_abs_path)).ok() {
if let Some(path) = dot_git_contents.strip_prefix("gitdir:") {
let path = path.trim();
let path = dot_git_abs_path
.parent()
.unwrap_or(Path::new(""))
.join(path);
if let Some(path) = smol::block_on(fs.canonicalize(&path)).log_err() {
repository_dir_abs_path = Path::new(&path).into();
common_dir_abs_path = repository_dir_abs_path.clone();
if let Some(ancestor_dot_git) = path
.ancestors()
.skip(1)
.find(|ancestor| smol::block_on(is_git_dir(ancestor, fs)))
{
common_dir_abs_path = ancestor_dot_git.into();
}
}
} else {
log::error!("failed to parse contents of .git file: {dot_git_contents:?}");
}
};
watcher.add(&common_dir_abs_path).log_err();
if !repository_dir_abs_path.starts_with(&common_dir_abs_path) {
watcher.add(&repository_dir_abs_path).log_err();
}
let work_directory_id = work_dir_entry.id;
@@ -5488,40 +5508,3 @@ impl CreatedEntry {
}
}
}
fn parse_gitfile(content: &str) -> anyhow::Result<&Path> {
let path = content
.strip_prefix("gitdir:")
.ok_or_else(|| anyhow!("failed to parse gitfile content {content:?}"))?;
Ok(Path::new(path.trim()))
}
fn discover_git_paths(dot_git_abs_path: &Arc<Path>, fs: &dyn Fs) -> (Arc<Path>, Arc<Path>) {
let mut repository_dir_abs_path = dot_git_abs_path.clone();
let mut common_dir_abs_path = dot_git_abs_path.clone();
if let Some(path) = smol::block_on(fs.load(&dot_git_abs_path))
.ok()
.as_ref()
.and_then(|contents| parse_gitfile(contents).log_err())
{
let path = dot_git_abs_path
.parent()
.unwrap_or(Path::new(""))
.join(path);
if let Some(path) = smol::block_on(fs.canonicalize(&path)).log_err() {
repository_dir_abs_path = Path::new(&path).into();
common_dir_abs_path = repository_dir_abs_path.clone();
if let Some(commondir_contents) = smol::block_on(fs.load(&path.join("commondir"))).ok()
{
if let Some(commondir_path) =
smol::block_on(fs.canonicalize(&path.join(commondir_contents.trim()))).log_err()
{
common_dir_abs_path = commondir_path.as_path().into();
}
}
}
};
(repository_dir_abs_path, common_dir_abs_path)
}

View File

@@ -2,7 +2,7 @@
description = "The fast, collaborative code editor."
edition.workspace = true
name = "zed"
version = "0.183.3"
version = "0.183.0"
publish.workspace = true
license = "GPL-3.0-or-later"
authors = ["Zed Team <hi@zed.dev>"]
@@ -133,8 +133,6 @@ util.workspace = true
uuid.workspace = true
vim.workspace = true
vim_mode_setting.workspace = true
web_search.workspace = true
web_search_providers.workspace = true
welcome.workspace = true
workspace.workspace = true
zed_actions.workspace = true

View File

@@ -1 +1 @@
preview
dev

View File

@@ -490,8 +490,6 @@ fn main() {
app_state.fs.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);
inline_completion_registry::init(
app_state.client.clone(),

View File

@@ -4258,8 +4258,6 @@ mod tests {
app_state.fs.clone(),
cx,
);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
assistant::init(
app_state.fs.clone(),

View File

@@ -150,7 +150,7 @@ pub mod command_palette {
pub mod feedback {
use gpui::actions;
actions!(feedback, [FileBugReport, GiveFeedback]);
actions!(feedback, [GiveFeedback]);
}
pub mod theme_selector {

View File

@@ -75,46 +75,6 @@ Non-negative `float` values
`float` values
## Bottom Dock Layout
- Description: Control the layout of the bottom dock, relative to the left and right docks
- Setting: `bottom_dock_layout`
- Default: `"contained"`
**Options**
1. Contain the bottom dock, giving the full height of the window to the left and right docks
```json
{
"bottom_dock_layout": "contained"
}
```
2. Give the bottom dock the full width of the window, truncating the left and right docks
```json
{
"bottom_dock_layout": "full"
}
```
3. Left align the bottom dock, truncating the left dock and giving the right dock the full height of the window
```json
{
"bottom_dock_layout": "left_aligned"
}
```
3. Right align the bottom dock, giving the left dock the full height of the window and truncating the right dock.
```json
{
"bottom_dock_layout": "right_aligned"
}
```
## Auto Install extensions
- Description: Define extensions to be autoinstalled or never be installed.
@@ -1351,10 +1311,10 @@ To interpret all `.c` files as C++, files called `MyLockFile` as TOML and files
"include_warnings": true,
"inline": {
"enabled": false
},
}
"update_with_cursor": false,
"primary_only": false,
"use_rendered": false
"use_rendered": false,
}
}
```

View File

@@ -291,15 +291,14 @@ To run tests in your Ruby project, you can set up custom tasks in your local `.z
```json
[
{
"label": "test $ZED_RELATIVE_FILE -n /$ZED_SYMBOL/",
"command": "bin/rails test $ZED_RELATIVE_FILE -n /$ZED_SYMBOL/",
"label": "test $ZED_RELATIVE_FILE:$ZED_ROW",
"command": "bin/rails",
"args": ["test", "\"$ZED_RELATIVE_FILE:$ZED_ROW\""],
"tags": ["ruby-test"]
}
]
```
Note: We can't use `args` here because of the way quotes are handled.
### Minitest
Plain minitest does not support running tests by line number, only by name, so we need to use `$ZED_SYMBOL` instead:

Some files were not shown because too many files have changed in this diff Show More