Compare commits
14 Commits
workflow-p
...
v0.148.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
623a463777 | ||
|
|
afb3fbff24 | ||
|
|
47c9b1ef70 | ||
|
|
ccd8a9af89 | ||
|
|
1dfc2fe1fa | ||
|
|
fb449399fc | ||
|
|
5ec8cdcb3c | ||
|
|
371b828d28 | ||
|
|
fdb5c7fbd3 | ||
|
|
84d68660a3 | ||
|
|
7e44cd04aa | ||
|
|
7b8a87b61c | ||
|
|
79e5ea7210 | ||
|
|
822a4ccb6b |
10
.github/workflows/ci.yml
vendored
@@ -231,20 +231,20 @@ jobs:
|
||||
mv target/x86_64-apple-darwin/release/Zed.dmg target/x86_64-apple-darwin/release/Zed-x86_64.dmg
|
||||
|
||||
- name: Upload app bundle (universal) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # v4
|
||||
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}.dmg
|
||||
path: target/release/Zed.dmg
|
||||
- name: Upload app bundle (aarch64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # v4
|
||||
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-aarch64.dmg
|
||||
path: target/aarch64-apple-darwin/release/Zed-aarch64.dmg
|
||||
|
||||
- name: Upload app bundle (x86_64) to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # v4
|
||||
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: Zed_${{ github.event.pull_request.head.sha || github.sha }}-x86_64.dmg
|
||||
@@ -319,7 +319,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # v4
|
||||
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz
|
||||
@@ -403,7 +403,7 @@ jobs:
|
||||
run: script/bundle-linux
|
||||
|
||||
- name: Upload Linux bundle to workflow run if main branch or specific label
|
||||
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # v4
|
||||
uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4
|
||||
if: ${{ github.ref == 'refs/heads/main' }} || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
with:
|
||||
name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz
|
||||
|
||||
2297
Cargo.lock
generated
@@ -1,2 +0,0 @@
|
||||
app: postgrest crates/collab/postgrest_app.conf
|
||||
llm: postgrest crates/collab/postgrest_llm.conf
|
||||
@@ -1,11 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_1896_18)">
|
||||
<path d="M11.094 3.09999H8.952L12.858 12.9H15L11.094 3.09999Z" fill="#1F1F1E"/>
|
||||
<path d="M4.906 3.09999L1 12.9H3.184L3.98284 10.842H8.06915L8.868 12.9H11.052L7.146 3.09999H4.906ZM4.68928 9.02199L6.026 5.57799L7.3627 9.02199H4.68928Z" fill="#1F1F1E"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_1896_18">
|
||||
<rect width="14" height="9.8" fill="white" transform="translate(1 3.09999)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
<path d="M3.43331 10.1846L6.66616 2.33334L9.89902 10.1846M3.43331 10.1846L1.9995 13.6667M3.43331 10.1846H9.89902M11.3328 13.6667L9.89902 10.1846" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M14.0613 13.647L9.34721 2.33334" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 530 B After Width: | Height: | Size: 459 B |
@@ -1,3 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M13.8695 8.14262C13.8695 11.6221 11.4867 14.0984 7.96785 14.0984C4.59408 14.0984 1.86949 11.3738 1.86949 7.99999C1.86949 4.62622 4.59408 1.90163 7.96785 1.90163C9.61048 1.90163 10.9924 2.50409 12.0572 3.49754L10.3974 5.09344C8.22605 2.99836 4.18834 4.57213 4.18834 7.99999C4.18834 10.127 5.88752 11.8508 7.96785 11.8508C10.3826 11.8508 11.2875 10.1197 11.4301 9.22213H7.96785V7.12458H13.7736C13.8301 7.43688 13.8695 7.73688 13.8695 8.14262Z" fill="black"/>
|
||||
<path d="M14.8695 8.16639C14.8695 12.2258 12.0896 15.1147 7.98425 15.1147C4.04818 15.1147 0.869492 11.9361 0.869492 7.99999C0.869492 4.06393 4.04818 0.885239 7.98425 0.885239C9.90064 0.885239 11.5129 1.58811 12.7551 2.74712L10.8187 4.60901C8.28547 2.16475 3.57482 4.00081 3.57482 7.99999C3.57482 10.4816 5.5572 12.4926 7.98425 12.4926C10.8015 12.4926 11.8572 10.4729 12.0236 9.42581H7.98425V6.97868H14.7576C14.8236 7.34303 14.8695 7.69303 14.8695 8.16639Z" fill="black"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 569 B After Width: | Height: | Size: 575 B |
|
Before Width: | Height: | Size: 8.5 KiB After Width: | Height: | Size: 1.1 KiB |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-microscope"><path d="M6 18h8"/><path d="M3 22h18"/><path d="M14 22a7 7 0 1 0 0-14h-1"/><path d="M9 14h2"/><path d="M9 12a2 2 0 0 1-2-2V6h6v4a2 2 0 0 1-2 2Z"/><path d="M12 6V3a1 1 0 0 0-1-1H9a1 1 0 0 0-1 1v3"/></svg>
|
||||
|
Before Width: | Height: | Size: 418 B |
@@ -1 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-undo"><path d="M3 7v6h6"/><path d="M21 17a9 9 0 0 0-9-9 9 9 0 0 0-6 2.3L3 13"/></svg>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-text-search"><path d="M21 6H3"/><path d="M10 12H3"/><path d="M10 18H3"/><circle cx="17" cy="15" r="3"/><path d="m21 19-1.9-1.9"/></svg>
|
||||
|
Before Width: | Height: | Size: 288 B After Width: | Height: | Size: 338 B |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-text-select"><path d="M5 3a2 2 0 0 0-2 2"/><path d="M19 3a2 2 0 0 1 2 2"/><path d="M21 19a2 2 0 0 1-2 2"/><path d="M5 21a2 2 0 0 1-2-2"/><path d="M9 3h1"/><path d="M9 21h1"/><path d="M14 3h1"/><path d="M14 21h1"/><path d="M3 9v1"/><path d="M21 9v1"/><path d="M3 14v1"/><path d="M21 14v1"/><line x1="7" x2="15" y1="8" y2="8"/><line x1="7" x2="17" y1="12" y2="12"/><line x1="7" x2="13" y1="16" y2="16"/></svg>
|
||||
|
Before Width: | Height: | Size: 610 B |
@@ -437,7 +437,7 @@
|
||||
"context": "Editor && showing_completions",
|
||||
"bindings": {
|
||||
"enter": "editor::ConfirmCompletion",
|
||||
"tab": "editor::ComposeCompletion"
|
||||
"tab": "editor::ConfirmCompletion"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -89,9 +89,9 @@
|
||||
"alt-right": "editor::MoveToNextWordEnd",
|
||||
"alt-f": "editor::MoveToNextWordEnd",
|
||||
"cmd-left": "editor::MoveToBeginningOfLine",
|
||||
"ctrl-a": ["editor::MoveToBeginningOfLine", { "stop_at_soft_wraps": false }],
|
||||
"ctrl-a": "editor::MoveToBeginningOfLine",
|
||||
"cmd-right": "editor::MoveToEndOfLine",
|
||||
"ctrl-e": ["editor::MoveToEndOfLine", { "stop_at_soft_wraps": false }],
|
||||
"ctrl-e": "editor::MoveToEndOfLine",
|
||||
"cmd-up": "editor::MoveToBeginning",
|
||||
"cmd-down": "editor::MoveToEnd",
|
||||
"shift-up": "editor::SelectUp",
|
||||
@@ -474,7 +474,7 @@
|
||||
"context": "Editor && showing_completions",
|
||||
"bindings": {
|
||||
"enter": "editor::ConfirmCompletion",
|
||||
"tab": "editor::ComposeCompletion"
|
||||
"tab": "editor::ConfirmCompletion"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"bindings": {
|
||||
"i": ["vim::PushOperator", { "Object": { "around": false } }],
|
||||
"a": ["vim::PushOperator", { "Object": { "around": true } }],
|
||||
":": "command_palette::Toggle",
|
||||
"h": "vim::Left",
|
||||
"left": "vim::Left",
|
||||
"backspace": "vim::Backspace",
|
||||
@@ -198,12 +199,17 @@
|
||||
"ctrl-6": "pane::AlternateFile"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "VimControl && VimCount",
|
||||
"bindings": {
|
||||
"0": ["vim::Number", 0]
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "vim_mode == normal",
|
||||
"bindings": {
|
||||
"escape": "editor::Cancel",
|
||||
"ctrl-[": "editor::Cancel",
|
||||
":": "command_palette::Toggle",
|
||||
".": "vim::Repeat",
|
||||
"c": ["vim::PushOperator", "Change"],
|
||||
"shift-c": "vim::ChangeToEndOfLine",
|
||||
@@ -251,17 +257,9 @@
|
||||
"g c": ["vim::PushOperator", "ToggleComments"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "VimControl && VimCount",
|
||||
"bindings": {
|
||||
"0": ["vim::Number", 0],
|
||||
":": "vim::CountCommand"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "vim_mode == visual",
|
||||
"bindings": {
|
||||
":": "vim::VisualCommand",
|
||||
"u": "vim::ConvertToLowerCase",
|
||||
"U": "vim::ConvertToUpperCase",
|
||||
"o": "vim::OtherEnd",
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{{#if language_name}}
|
||||
File language: {{language_name}}
|
||||
Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
|
||||
{{else}}
|
||||
Here's a file of text that I'm going to ask you to make an edit to.
|
||||
{{/if}}
|
||||
|
||||
{{#if is_insert}}
|
||||
The point you'll need to insert at is marked with <insert_here></insert_here>.
|
||||
{{else}}
|
||||
The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
|
||||
{{/if}}
|
||||
|
||||
<document>
|
||||
@@ -7,35 +15,47 @@ File language: {{language_name}}
|
||||
</document>
|
||||
|
||||
{{#if is_truncated}}
|
||||
Note: Context around the relevant section has been truncated for brevity.
|
||||
The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
|
||||
{{/if}}
|
||||
|
||||
Editing instructions:
|
||||
1. Rewrite the section marked with <rewrite_this></rewrite_this> tags based on this prompt:
|
||||
{{#if is_insert}}
|
||||
You can't replace {{content_type}}, your answer will be inserted in place of the `<insert_here></insert_here>` tags. Don't include the insert_here tags in your output.
|
||||
|
||||
Generate {{content_type}} based on the following prompt:
|
||||
|
||||
<prompt>
|
||||
{{{user_prompt}}}
|
||||
</prompt>
|
||||
|
||||
2. Within <rewrite_this></rewrite_this>, make changes only in these subsections:
|
||||
{{#if has_insertion}}
|
||||
- Insert text where marked with <insert_here></insert_here> tags
|
||||
{{/if}}
|
||||
{{#if has_replacement}}
|
||||
- Edit text surrounded by <edit_here></edit_here> tags
|
||||
{{/if}}
|
||||
Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines.
|
||||
|
||||
Immediately start with the following format with no remarks:
|
||||
|
||||
```
|
||||
{{INSERTED_CODE}}
|
||||
```
|
||||
{{else}}
|
||||
Edit the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
|
||||
|
||||
<prompt>
|
||||
{{{user_prompt}}}
|
||||
</prompt>
|
||||
|
||||
{{#if rewrite_section}}
|
||||
And here's the section to rewrite based on that prompt again for reference:
|
||||
|
||||
3. Section to edit:
|
||||
<rewrite_this>
|
||||
{{{rewrite_section_with_selections}}}
|
||||
{{{rewrite_section}}}
|
||||
</rewrite_this>
|
||||
{{/if}}
|
||||
|
||||
4. Guidelines:
|
||||
- Only make changes necessary to fulfill the prompt
|
||||
- Preserve all surrounding {{content_type}}
|
||||
- Maintain the original indentation level
|
||||
- Rewrite the entire section, even if no changes are needed
|
||||
- Do not include <rewrite_this>, <insert_here>, or <edit_here> tags in your output
|
||||
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
|
||||
|
||||
Output format:
|
||||
Immediately start with the following, ensuring no leading whitespace:
|
||||
```{{REWRITTEN_CODE}}```
|
||||
Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions.
|
||||
|
||||
Immediately start with the following format with no remarks:
|
||||
|
||||
```
|
||||
{{REWRITTEN_CODE}}
|
||||
```
|
||||
{{/if}}
|
||||
|
||||
@@ -1,295 +1,87 @@
|
||||
<workflow>
|
||||
Guide the user through code changes in numbered steps that focus on individual functions, type definitions, etc.
|
||||
Surround each distinct step in a <step></step> XML tag.
|
||||
|
||||
<instructions>
|
||||
- Use the language of the file for code fence blocks unless otherwise specified.
|
||||
- Include a code or file action in each step.
|
||||
- Provide error handling and input validation where appropriate.
|
||||
- Adapt explanations based on the user's perceived level of expertise.
|
||||
- Include comments in code examples to enhance understanding.
|
||||
- Consider more complex programming scenarios when relevant.
|
||||
</instructions>
|
||||
Here's an example of a user-assistant dialog with step numbers and a follow-up question in Rust:
|
||||
|
||||
Examples of user-assistant dialogs with step numbers and user follow-up requests:
|
||||
|
||||
<example language="rust">
|
||||
<turn number="1">
|
||||
<user>Create a Library struct with a vector of Book structs. Include methods to add books and find books by author. Use generics for the book ID.</user>
|
||||
<assistant>Certainly! Let's create a Library struct with a vector of Book structs, including methods to add books and find books by author, using generics for the book ID:
|
||||
|
||||
<step>Define the Book and Library structs
|
||||
1. User: I want to create a User struct with name and age fields, and a method to print the user's information. Can you help me with that?
|
||||
A: Certainly! Let's create a User struct with the specified fields and a method to print the information:
|
||||
<step>Define the User struct
|
||||
```rust
|
||||
use std::fmt::Display;
|
||||
|
||||
struct Book<T>
|
||||
where
|
||||
T: Display + Clone
|
||||
{
|
||||
id: T,
|
||||
title: String,
|
||||
author: String,
|
||||
}
|
||||
|
||||
struct Library<T>
|
||||
where
|
||||
T: Display + Clone
|
||||
{
|
||||
books: Vec<Book<T>>,
|
||||
struct User {
|
||||
name: String,
|
||||
age: u32,
|
||||
}
|
||||
```
|
||||
</step>
|
||||
<step>Implement methods for the Library struct
|
||||
<step>Implement methods for the User struct
|
||||
```rust
|
||||
impl<T> Library<T>
|
||||
where
|
||||
T: Display + Clone
|
||||
{
|
||||
fn new() -> Self {
|
||||
Library { books: Vec::new() }
|
||||
impl User {
|
||||
fn new(name: String, age: u32) -> Self {
|
||||
User { name, age }
|
||||
}
|
||||
|
||||
fn add_book(&mut self, id: T, title: String, author: String) {
|
||||
let book = Book { id, title, author };
|
||||
self.books.push(book);
|
||||
}
|
||||
|
||||
fn find_books_by_author(&self, author: &str) -> Vec<&Book<T>> {
|
||||
self.books
|
||||
.iter()
|
||||
.filter(|book| book.author.to_lowercase() == author.to_lowercase())
|
||||
.collect()
|
||||
fn print_info(&self) {
|
||||
println!("Name: {}, Age: {}", self.name, self.age);
|
||||
}
|
||||
}
|
||||
```
|
||||
</step>
|
||||
|
||||
This code defines:
|
||||
1. A generic Book struct with an ID of type T, title, and author
|
||||
2. A Library struct containing a vector of Books
|
||||
3. Methods to create a new Library, add books, and find books by author
|
||||
The use of generics allows flexibility in the type of ID used for books.</assistant>
|
||||
</turn>
|
||||
This code defines a User struct and implements two methods:
|
||||
- `new`: a constructor to create a new User
|
||||
- `print_info`: a method to print the user's information
|
||||
|
||||
<turn number="2">
|
||||
<user>How can I use this Library struct in a main function?</user>
|
||||
<assistant>Here's how you can use the Library struct in a main function:
|
||||
2. User: Great! Now, how can I create a User instance and call the print_info method?
|
||||
A: You can create a User instance and call the print_info method like this:
|
||||
|
||||
<step>Create a main function to demonstrate Library usage
|
||||
<step>Create a User instance and call print_info
|
||||
```rust
|
||||
fn main() {
|
||||
// Create a new Library with string IDs
|
||||
let mut library = Library::new();
|
||||
|
||||
// Add some books
|
||||
library.add_book(String::from("001"), String::from("The Rust Programming Language"), String::from("Steve Klabnik"));
|
||||
library.add_book(String::from("002"), String::from("Programming Rust"), String::from("Jim Blandy"));
|
||||
library.add_book(String::from("003"), String::from("Rust in Action"), String::from("Tim McNamara"));
|
||||
|
||||
// Find books by author
|
||||
let author = "Steve Klabnik";
|
||||
let books = library.find_books_by_author(author);
|
||||
|
||||
println!("Books by {}:", author);
|
||||
for book in books {
|
||||
println!("ID: {}, Title: {}", book.id, book.title);
|
||||
}
|
||||
let user = User::new(String::from("Alice"), 30);
|
||||
user.print_info();
|
||||
}
|
||||
```
|
||||
</step>
|
||||
|
||||
This main function demonstrates:
|
||||
1. Creating a new Library instance
|
||||
2. Adding books with string IDs
|
||||
3. Finding and displaying books by a specific author</assistant>
|
||||
</turn>
|
||||
This will output: `Name: Alice, Age: 30`
|
||||
|
||||
<turn number="3">
|
||||
<user>Can you add a method to remove a book by ID and handle potential errors?</user>
|
||||
<assistant>Certainly! Let's add a method to remove a book by ID and handle potential errors:
|
||||
3. User: Can you add an email field to the User struct and update the print_info method to include it?
|
||||
A: Certainly! Let's modify the User struct and update the print_info method:
|
||||
|
||||
<step>Add a remove_book method to the Library struct
|
||||
<step>Add email field to User struct
|
||||
```rust
|
||||
use std::fmt::Display;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BookNotFoundError;
|
||||
|
||||
impl Error for BookNotFoundError {}
|
||||
|
||||
impl fmt::Display for BookNotFoundError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Book not found")
|
||||
}
|
||||
struct User {
|
||||
name: String,
|
||||
age: u32,
|
||||
email: String,
|
||||
}
|
||||
```
|
||||
</step>
|
||||
|
||||
impl<T> Library<T>
|
||||
where
|
||||
T: Display + Clone + PartialEq
|
||||
{
|
||||
// ... (previous methods remain the same)
|
||||
|
||||
fn remove_book(&mut self, id: &T) -> Result<Book<T>, BookNotFoundError> {
|
||||
if let Some(index) = self.books.iter().position(|book| &book.id == id) {
|
||||
Ok(self.books.remove(index))
|
||||
} else {
|
||||
Err(BookNotFoundError)
|
||||
}
|
||||
<step>Update the new method and print_info method
|
||||
```rust
|
||||
impl User {
|
||||
fn new(name: String, age: u32, email: String) -> Self {
|
||||
User { name, age, email }
|
||||
}
|
||||
fn print_info(&self) {
|
||||
println!("Name: {}, Age: {}, Email: {}", self.name, self.age, self.email);
|
||||
}
|
||||
}
|
||||
```
|
||||
</step>
|
||||
<step>Update the main function to demonstrate book removal
|
||||
|
||||
<step>Update the main function
|
||||
```rust
|
||||
fn main() {
|
||||
// ... (previous code remains the same)
|
||||
|
||||
// Attempt to remove a book
|
||||
let id_to_remove = String::from("002");
|
||||
match library.remove_book(&id_to_remove) {
|
||||
Ok(removed_book) => println!("Removed book: {} by {}", removed_book.title, removed_book.author),
|
||||
Err(e) => println!("Error removing book: {}", e),
|
||||
}
|
||||
|
||||
// Attempt to remove a non-existent book
|
||||
let non_existent_id = String::from("999");
|
||||
match library.remove_book(&non_existent_id) {
|
||||
Ok(_) => println!("Book removed successfully"),
|
||||
Err(e) => println!("Error removing book: {}", e),
|
||||
}
|
||||
|
||||
// Display remaining books
|
||||
println!("\nRemaining books:");
|
||||
for book in &library.books {
|
||||
println!("ID: {}, Title: {} by {}", book.id, book.title, book.author);
|
||||
}
|
||||
let user = User::new(String::from("Alice"), 30, String::from("alice@example.com"));
|
||||
user.print_info();
|
||||
}
|
||||
```
|
||||
</step>
|
||||
|
||||
This update includes:
|
||||
1. A new remove_book method that returns a Result type
|
||||
2. A custom BookNotFoundError for error handling
|
||||
3. Demonstration of successful book removal and error handling for non-existent books
|
||||
4. Display of remaining books after removal attempts
|
||||
|
||||
The remove_book method uses the PartialEq trait to compare book IDs, allowing for flexible ID types while maintaining type safety.</assistant>
|
||||
</turn>
|
||||
</example>
|
||||
|
||||
<example language="go">
|
||||
<turn number="1">
|
||||
<user>Create a Person struct with name, age, and email fields, and a method to greet. Include error handling for age.</user>
|
||||
<assistant>Certainly! Let's create a Person struct with the specified fields and a method to greet in Go:
|
||||
|
||||
<step>Define the Person struct and NewPerson function with error handling
|
||||
```go
|
||||
type Person struct {
|
||||
Name string
|
||||
Age int
|
||||
Email string
|
||||
}
|
||||
|
||||
func NewPerson(name string, age int, email string) (*Person, error) {
|
||||
if age < 0 || age > 150 {
|
||||
return nil, fmt.Errorf("invalid age: %d", age)
|
||||
}
|
||||
return &Person{Name: name, Age: age, Email: email}, nil
|
||||
}
|
||||
```
|
||||
|
||||
Implement the Greet method for the Person struct
|
||||
```go
|
||||
func (p *Person) Greet() string {
|
||||
return fmt.Sprintf("Hello, I'm %s, %d years old. You can reach me at %s.", p.Name, p.Age, p.Email)
|
||||
}
|
||||
```
|
||||
|
||||
This code defines a Person struct, a NewPerson function with age validation, and a Greet method.</assistant>
|
||||
</turn>
|
||||
|
||||
<turn number="2">
|
||||
<user>How can I use these in a main function?</user>
|
||||
<assistant>Here's how to use the Person struct and its methods in a main function:
|
||||
|
||||
<step>Create a main function with error handling
|
||||
```go
|
||||
func main() {
|
||||
// Create a new Person instance
|
||||
person, err := NewPerson("Alice", 30, "alice@example.com")
|
||||
if err != nil {
|
||||
fmt.Println("Error creating person:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use the Greet method
|
||||
fmt.Println(person.Greet())
|
||||
|
||||
// Demonstrate error handling with invalid age
|
||||
invalidPerson, err := NewPerson("Bob", 200, "bob@example.com")
|
||||
if err != nil {
|
||||
fmt.Println("Error creating person with invalid age:", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
</step>
|
||||
|
||||
This main function demonstrates:
|
||||
1. Creating a valid Person instance
|
||||
2. Using the Greet method
|
||||
3. Handling errors when creating a Person with an invalid age</assistant>
|
||||
</turn>
|
||||
|
||||
<turn number="3">
|
||||
<user>Can you add a method to update the email with validation?</user>
|
||||
<assistant>Certainly! Let's add a method to update the email with validation:
|
||||
|
||||
<step>Add UpdateEmail method with validation
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
func (p *Person) UpdateEmail(newEmail string) error {
|
||||
// Simple email validation using regex
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
if !emailRegex.MatchString(newEmail) {
|
||||
return fmt.Errorf("invalid email format: %s", newEmail)
|
||||
}
|
||||
p.Email = newEmail
|
||||
return nil
|
||||
}
|
||||
```
|
||||
</step>
|
||||
<step>Update the main function to demonstrate email update
|
||||
```go
|
||||
func main() {
|
||||
// ... (previous code remains the same)
|
||||
|
||||
// Demonstrate email update
|
||||
err = person.UpdateEmail("alice.new@example.com")
|
||||
if err != nil {
|
||||
fmt.Println("Error updating email:", err)
|
||||
} else {
|
||||
fmt.Println("Updated greeting:", person.Greet())
|
||||
}
|
||||
|
||||
// Demonstrate email update with invalid email
|
||||
err = person.UpdateEmail("invalid-email")
|
||||
if err != nil {
|
||||
fmt.Println("Error updating email:", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
</step>
|
||||
|
||||
This update includes:
|
||||
1. An UpdateEmail method with email format validation
|
||||
2. Demonstration of successful email update in the main function
|
||||
3. Handling of invalid email update attempt</assistant>
|
||||
</turn>
|
||||
</example>
|
||||
</workflow>
|
||||
This will now output: `Name: Alice, Age: 30, Email: alice@example.com`
|
||||
The changes include:
|
||||
1. Adding the `email` field to the User struct
|
||||
2. Updating the `new` method to accept an email parameter
|
||||
3. Modifying the `print_info` method to include the email
|
||||
4. Updating the main function to provide an email when creating a User instance
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
Your task is to map a step from the conversation above to suggestions on symbols inside the provided source files.
|
||||
Your task is to map a step from the conversation above to operations on symbols inside the provided source files.
|
||||
|
||||
Guidelines:
|
||||
- There's no need to describe *what* to do, just *where* to do it.
|
||||
@@ -6,13 +6,13 @@ Guidelines:
|
||||
- Don't create and then update a file.
|
||||
- We'll create it in one shot.
|
||||
- Prefer updating symbols lower in the syntax tree if possible.
|
||||
- Never include suggestions on a parent symbol and one of its children in the same suggestions block.
|
||||
- Never nest an operation with another operation or include CDATA or other content. All suggestions are leaf nodes.
|
||||
- Never include operations on a parent symbol and one of its children in the same operations block.
|
||||
- Never nest an operation with another operation or include CDATA or other content. All operations are leaf nodes.
|
||||
- Include a description attribute for each operation with a brief, one-line description of the change to perform.
|
||||
- Descriptions are required for all suggestions except delete.
|
||||
- When generating multiple suggestions, ensure the descriptions are specific to each individual operation.
|
||||
- Descriptions are required for all operations except delete.
|
||||
- When generating multiple operations, ensure the descriptions are specific to each individual operation.
|
||||
- Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
|
||||
- Don't generate multiple suggestions at the same location. Instead, combine them together in a single operation with a succinct combined description.
|
||||
- Don't generate multiple operations at the same location. Instead, combine them together in a single operation with a succinct combined description.
|
||||
|
||||
Example 1:
|
||||
|
||||
@@ -33,12 +33,12 @@ impl Rectangle {
|
||||
<step>Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct</step>
|
||||
<step>Implement the 'Display' trait for the Rectangle struct</step>
|
||||
|
||||
What are the suggestions for the step: <step>Add a new method 'calculate_area' to the Rectangle struct</step>
|
||||
What are the operations for the step: <step>Add a new method 'calculate_area' to the Rectangle struct</step>
|
||||
|
||||
A (wrong):
|
||||
{
|
||||
"title": "Add Rectangle methods",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "AppendChild",
|
||||
"path": "src/shapes.rs",
|
||||
@@ -59,7 +59,7 @@ This demonstrates what NOT to do. NEVER append multiple children at the same loc
|
||||
A (corrected):
|
||||
{
|
||||
"title": "Add Rectangle methods",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "AppendChild",
|
||||
"path": "src/shapes.rs",
|
||||
@@ -70,12 +70,12 @@ A (corrected):
|
||||
}
|
||||
|
||||
User:
|
||||
What are the suggestions for the step: <step>Implement the 'Display' trait for the Rectangle struct</step>
|
||||
What are the operations for the step: <step>Implement the 'Display' trait for the Rectangle struct</step>
|
||||
|
||||
A:
|
||||
{
|
||||
"title": "Implement Display for Rectangle",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "InsertSiblingAfter",
|
||||
"path": "src/shapes.rs",
|
||||
@@ -109,12 +109,12 @@ impl User {
|
||||
<step>Update the 'print_info' method to use formatted output</step>
|
||||
<step>Remove the 'email' field from the User struct</step>
|
||||
|
||||
What are the suggestions for the step: <step>Update the 'print_info' method to use formatted output</step>
|
||||
What are the operations for the step: <step>Update the 'print_info' method to use formatted output</step>
|
||||
|
||||
A:
|
||||
{
|
||||
"title": "Use formatted output",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/user.rs",
|
||||
@@ -125,12 +125,12 @@ A:
|
||||
}
|
||||
|
||||
User:
|
||||
What are the suggestions for the step: <step>Remove the 'email' field from the User struct</step>
|
||||
What are the operations for the step: <step>Remove the 'email' field from the User struct</step>
|
||||
|
||||
A:
|
||||
{
|
||||
"title": "Remove email field",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "Delete",
|
||||
"path": "src/user.rs",
|
||||
@@ -163,12 +163,12 @@ impl Vehicle {
|
||||
<step>Add a 'use std::fmt;' statement at the beginning of the file</step>
|
||||
<step>Add a new method 'start_engine' in the Vehicle impl block</step>
|
||||
|
||||
What are the suggestions for the step: <step>Add a 'use std::fmt;' statement at the beginning of the file</step>
|
||||
What are the operations for the step: <step>Add a 'use std::fmt;' statement at the beginning of the file</step>
|
||||
|
||||
A:
|
||||
{
|
||||
"title": "Add use std::fmt statement",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "PrependChild",
|
||||
"path": "src/vehicle.rs",
|
||||
@@ -178,12 +178,12 @@ A:
|
||||
}
|
||||
|
||||
User:
|
||||
What are the suggestions for the step: <step>Add a new method 'start_engine' in the Vehicle impl block</step>
|
||||
What are the operations for the step: <step>Add a new method 'start_engine' in the Vehicle impl block</step>
|
||||
|
||||
A:
|
||||
{
|
||||
"title": "Add start_engine method",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "InsertSiblingAfter",
|
||||
"path": "src/vehicle.rs",
|
||||
@@ -222,12 +222,12 @@ impl Employee {
|
||||
|
||||
<step>Make salary an f32</step>
|
||||
|
||||
What are the suggestions for the step: <step>Make salary an f32</step>
|
||||
What are the operations for the step: <step>Make salary an f32</step>
|
||||
|
||||
A (wrong):
|
||||
{
|
||||
"title": "Change salary to f32",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/employee.rs",
|
||||
@@ -248,7 +248,7 @@ This example demonstrates what not to do. `struct Employee salary` is a child of
|
||||
A (corrected):
|
||||
{
|
||||
"title": "Change salary to f32",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "Update",
|
||||
"path": "src/employee.rs",
|
||||
@@ -259,12 +259,12 @@ A (corrected):
|
||||
}
|
||||
|
||||
User:
|
||||
What are the correct suggestions for the step: <step>Remove the 'department' field and update the 'print_details' method</step>
|
||||
What are the correct operations for the step: <step>Remove the 'department' field and update the 'print_details' method</step>
|
||||
|
||||
A:
|
||||
{
|
||||
"title": "Remove department",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "Delete",
|
||||
"path": "src/employee.rs",
|
||||
@@ -311,7 +311,7 @@ impl Game {
|
||||
A:
|
||||
{
|
||||
"title": "Add level field to Player",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "InsertSiblingAfter",
|
||||
"path": "src/game.rs",
|
||||
@@ -349,7 +349,7 @@ impl Config {
|
||||
A:
|
||||
{
|
||||
"title": "Add load_from_file method",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "PrependChild",
|
||||
"path": "src/config.rs",
|
||||
@@ -389,7 +389,7 @@ impl Database {
|
||||
A:
|
||||
{
|
||||
"title": "Add error handling to query",
|
||||
"suggestions": [
|
||||
"operations": [
|
||||
{
|
||||
"kind": "PrependChild",
|
||||
"path": "src/database.rs",
|
||||
@@ -410,4 +410,4 @@ A:
|
||||
]
|
||||
}
|
||||
|
||||
Now generate the suggestions for the following step:
|
||||
Now generate the operations for the following step:
|
||||
|
||||
@@ -400,19 +400,6 @@
|
||||
"model": "gpt-4o"
|
||||
}
|
||||
},
|
||||
// The settings for slash commands.
|
||||
"slash_commands": {
|
||||
// Settings for the `/docs` slash command.
|
||||
"docs": {
|
||||
// Whether `/docs` is enabled.
|
||||
"enabled": false
|
||||
},
|
||||
// Settings for the `/project` slash command.
|
||||
"project": {
|
||||
// Whether `/project` is enabled.
|
||||
"enabled": false
|
||||
}
|
||||
},
|
||||
// Whether the screen sharing icon is shown in the os status bar.
|
||||
"show_call_status_icon": true,
|
||||
// Whether to use language servers to provide code intelligence.
|
||||
|
||||
@@ -24,7 +24,6 @@ schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
thiserror.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
mod supported_countries;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
use strum::EnumIter;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
@@ -18,13 +16,13 @@ pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
#[default]
|
||||
#[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
|
||||
#[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
|
||||
Claude3_5Sonnet,
|
||||
#[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
|
||||
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
|
||||
#[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
|
||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||
Claude3Haiku,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
@@ -55,7 +53,7 @@ impl Model {
|
||||
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
|
||||
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||
Model::Claude3Haiku => "claude-3-haiku-20240307",
|
||||
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
@@ -98,7 +96,7 @@ pub async fn complete(
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<Response, AnthropicError> {
|
||||
) -> Result<Response> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
@@ -108,40 +106,24 @@ pub async fn complete(
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
let serialized_request =
|
||||
serde_json::to_string(&request).context("failed to serialize request")?;
|
||||
let request = request_builder
|
||||
.body(AsyncBody::from(serialized_request))
|
||||
.context("failed to construct request body")?;
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
let mut response = client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send request to Anthropic")?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("failed to read response body")?;
|
||||
let response_message: Response =
|
||||
serde_json::from_slice(&body).context("failed to deserialize response body")?;
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let response_message: Response = serde_json::from_slice(&body)?;
|
||||
Ok(response_message)
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("failed to read response body")?;
|
||||
let body_str =
|
||||
std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
|
||||
Err(AnthropicError::Other(anyhow!(
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str
|
||||
)))
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +133,7 @@ pub async fn stream_completion(
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
|
||||
) -> Result<BoxStream<'static, Result<Event>>> {
|
||||
let request = StreamingRequest {
|
||||
base: request,
|
||||
stream: true,
|
||||
@@ -167,16 +149,10 @@ pub async fn stream_completion(
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
}
|
||||
let serialized_request =
|
||||
serde_json::to_string(&request).context("failed to serialize request")?;
|
||||
let request = request_builder
|
||||
.body(AsyncBody::from(serialized_request))
|
||||
.context("failed to construct request body")?;
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
let mut response = client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send request to Anthropic")?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
@@ -187,41 +163,36 @@ pub async fn stream_completion(
|
||||
let line = line.strip_prefix("data: ")?;
|
||||
match serde_json::from_str(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("failed to read response body")?;
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let body_str =
|
||||
std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
match serde_json::from_str::<Event>(body_str) {
|
||||
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
|
||||
Ok(_) => Err(AnthropicError::Other(anyhow!(
|
||||
Ok(Event::Error { error }) => Err(api_error_to_err(error)),
|
||||
Ok(_) => Err(anyhow!(
|
||||
"Unexpected success response while expecting an error: '{body_str}'",
|
||||
))),
|
||||
Err(_) => Err(AnthropicError::Other(anyhow!(
|
||||
)),
|
||||
Err(_) => Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str,
|
||||
))),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_text_from_events(
|
||||
response: impl Stream<Item = Result<Event, AnthropicError>>,
|
||||
) -> impl Stream<Item = Result<String, AnthropicError>> {
|
||||
response: impl Stream<Item = Result<Event>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
response.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
@@ -233,7 +204,7 @@ pub fn extract_text_from_events(
|
||||
ContentDelta::TextDelta { text } => Some(Ok(text)),
|
||||
_ => None,
|
||||
},
|
||||
Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
|
||||
Event::Error { error } => Some(Err(api_error_to_err(error))),
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
@@ -241,6 +212,15 @@ pub fn extract_text_from_events(
|
||||
})
|
||||
}
|
||||
|
||||
fn api_error_to_err(
|
||||
ApiError {
|
||||
error_type,
|
||||
message,
|
||||
}: ApiError,
|
||||
) -> anyhow::Error {
|
||||
anyhow!("API error. Type: '{error_type}', message: '{message}'",)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
@@ -394,53 +374,9 @@ pub struct MessageDelta {
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AnthropicError {
|
||||
#[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
|
||||
ApiError(ApiError),
|
||||
#[error("{0}")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ApiError {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// An Anthropic API error code.
|
||||
/// https://docs.anthropic.com/en/api/errors#http-errors
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum ApiErrorCode {
|
||||
/// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
|
||||
InvalidRequestError,
|
||||
/// 401 - `authentication_error`: There's an issue with your API key.
|
||||
AuthenticationError,
|
||||
/// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
|
||||
PermissionError,
|
||||
/// 404 - `not_found_error`: The requested resource was not found.
|
||||
NotFoundError,
|
||||
/// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
|
||||
RequestTooLarge,
|
||||
/// 429 - `rate_limit_error`: Your account has hit a rate limit.
|
||||
RateLimitError,
|
||||
/// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
|
||||
ApiError,
|
||||
/// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
|
||||
OverloadedError,
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
pub fn code(&self) -> Option<ApiErrorCode> {
|
||||
ApiErrorCode::from_str(&self.error_type).ok()
|
||||
}
|
||||
|
||||
pub fn is_rate_limit_error(&self) -> bool {
|
||||
match self.error_type.as_str() {
|
||||
"rate_limit_error" => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,6 @@ semantic_index.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smallvec.workspace = true
|
||||
smol.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
terminal.workspace = true
|
||||
|
||||
@@ -3,14 +3,12 @@
|
||||
pub mod assistant_panel;
|
||||
pub mod assistant_settings;
|
||||
mod context;
|
||||
pub(crate) mod context_inspector;
|
||||
pub mod context_store;
|
||||
mod inline_assistant;
|
||||
mod model_selector;
|
||||
mod prompt_library;
|
||||
mod prompts;
|
||||
mod slash_command;
|
||||
pub mod slash_command_settings;
|
||||
mod streaming_diff;
|
||||
mod terminal_inline_assistant;
|
||||
|
||||
@@ -21,7 +19,6 @@ use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub use context::*;
|
||||
pub use context_store::*;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use fs::Fs;
|
||||
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||
use indexed_docs::IndexedDocsRegistry;
|
||||
@@ -31,21 +28,18 @@ use language_model::{
|
||||
};
|
||||
pub(crate) use model_selector::*;
|
||||
pub use prompts::PromptBuilder;
|
||||
use prompts::PromptOverrideContext;
|
||||
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{update_settings_file, Settings, SettingsStore};
|
||||
use slash_command::{
|
||||
default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command,
|
||||
project_command, prompt_command, search_command, symbols_command, tabs_command,
|
||||
terminal_command, workflow_command,
|
||||
active_command, default_command, diagnostics_command, docs_command, fetch_command,
|
||||
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
|
||||
tabs_command, term_command, workflow_command,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
pub(crate) use streaming_diff::*;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::slash_command_settings::SlashCommandSettings;
|
||||
|
||||
actions!(
|
||||
assistant,
|
||||
[
|
||||
@@ -61,7 +55,7 @@ actions!(
|
||||
DeployPromptLibrary,
|
||||
ConfirmCommand,
|
||||
ToggleModelSelector,
|
||||
DebugWorkflowSteps
|
||||
DebugEditSteps
|
||||
]
|
||||
);
|
||||
|
||||
@@ -102,7 +96,6 @@ pub enum MessageStatus {
|
||||
Pending,
|
||||
Done,
|
||||
Error(SharedString),
|
||||
Canceled,
|
||||
}
|
||||
|
||||
impl MessageStatus {
|
||||
@@ -113,7 +106,6 @@ impl MessageStatus {
|
||||
Some(proto::context_message_status::Variant::Error(error)) => {
|
||||
MessageStatus::Error(error.message.into())
|
||||
}
|
||||
Some(proto::context_message_status::Variant::Canceled(_)) => MessageStatus::Canceled,
|
||||
None => MessageStatus::Pending,
|
||||
}
|
||||
}
|
||||
@@ -137,11 +129,6 @@ impl MessageStatus {
|
||||
},
|
||||
)),
|
||||
},
|
||||
MessageStatus::Canceled => proto::ContextMessageStatus {
|
||||
variant: Some(proto::context_message_status::Variant::Canceled(
|
||||
proto::context_message_status::Canceled {},
|
||||
)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -179,15 +166,9 @@ impl Assistant {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
client: Arc<Client>,
|
||||
dev_mode: bool,
|
||||
cx: &mut AppContext,
|
||||
) -> Arc<PromptBuilder> {
|
||||
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
|
||||
cx.set_global(Assistant::default());
|
||||
AssistantSettings::register(cx);
|
||||
SlashCommandSettings::register(cx);
|
||||
|
||||
// TODO: remove this when 0.148.0 is released.
|
||||
if AssistantSettings::get_global(cx).using_outdated_settings_version {
|
||||
@@ -220,14 +201,10 @@ pub fn init(
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
|
||||
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
|
||||
dev_mode,
|
||||
fs: fs.clone(),
|
||||
cx,
|
||||
}))
|
||||
.log_err()
|
||||
.map(Arc::new)
|
||||
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
||||
let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
|
||||
.log_err()
|
||||
.map(Arc::new)
|
||||
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
|
||||
register_slash_commands(Some(prompt_builder.clone()), cx);
|
||||
inline_assistant::init(
|
||||
fs.clone(),
|
||||
@@ -293,15 +270,17 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
|
||||
fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
slash_command_registry.register_command(file_command::FileSlashCommand, true);
|
||||
slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
|
||||
slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
|
||||
slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
|
||||
slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
|
||||
slash_command_registry.register_command(search_command::SearchSlashCommand, true);
|
||||
slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
|
||||
slash_command_registry.register_command(default_command::DefaultSlashCommand, false);
|
||||
slash_command_registry.register_command(terminal_command::TerminalSlashCommand, true);
|
||||
slash_command_registry.register_command(now_command::NowSlashCommand, false);
|
||||
slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
|
||||
slash_command_registry.register_command(term_command::TermSlashCommand, true);
|
||||
slash_command_registry.register_command(now_command::NowSlashCommand, true);
|
||||
slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
|
||||
|
||||
slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
|
||||
if let Some(prompt_builder) = prompt_builder {
|
||||
slash_command_registry.register_command(
|
||||
workflow_command::WorkflowSlashCommand::new(prompt_builder),
|
||||
@@ -309,37 +288,6 @@ fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut
|
||||
);
|
||||
}
|
||||
slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
|
||||
|
||||
update_slash_commands_from_settings(cx);
|
||||
cx.observe_global::<SettingsStore>(update_slash_commands_from_settings)
|
||||
.detach();
|
||||
|
||||
cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
|
||||
let slash_command_registry = slash_command_registry.clone();
|
||||
move |is_enabled, _cx| {
|
||||
if is_enabled {
|
||||
slash_command_registry.register_command(search_command::SearchSlashCommand, true);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn update_slash_commands_from_settings(cx: &mut AppContext) {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
let settings = SlashCommandSettings::get_global(cx);
|
||||
|
||||
if settings.docs.enabled {
|
||||
slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
|
||||
} else {
|
||||
slash_command_registry.unregister_command(docs_command::DocsSlashCommand);
|
||||
}
|
||||
|
||||
if settings.project.enabled {
|
||||
slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
|
||||
} else {
|
||||
slash_command_registry.unregister_command(project_command::ProjectSlashCommand);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn humanize_token_count(count: usize) -> String {
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::{
|
||||
display_map::{BlockDisposition, BlockProperties, BlockStyle, CustomBlockId},
|
||||
Editor,
|
||||
};
|
||||
use gpui::{AppContext, Model, View};
|
||||
use text::{Bias, ToOffset, ToPoint};
|
||||
use ui::{
|
||||
div, h_flex, px, Color, Element as _, ParentElement as _, Styled, ViewContext, WindowContext,
|
||||
};
|
||||
|
||||
use crate::{Context, ResolvedWorkflowStep, WorkflowSuggestion};
|
||||
|
||||
type StepRange = Range<language::Anchor>;
|
||||
|
||||
struct DebugInfo {
|
||||
range: Range<editor::Anchor>,
|
||||
block_id: CustomBlockId,
|
||||
}
|
||||
|
||||
pub(crate) struct ContextInspector {
|
||||
active_debug_views: HashMap<Range<language::Anchor>, DebugInfo>,
|
||||
context: Model<Context>,
|
||||
editor: View<Editor>,
|
||||
}
|
||||
|
||||
impl ContextInspector {
|
||||
pub(crate) fn new(editor: View<Editor>, context: Model<Context>) -> Self {
|
||||
Self {
|
||||
editor,
|
||||
context,
|
||||
active_debug_views: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_active(&self, range: &StepRange) -> bool {
|
||||
self.active_debug_views.contains_key(range)
|
||||
}
|
||||
|
||||
pub(crate) fn refresh(&mut self, range: &StepRange, cx: &mut WindowContext<'_>) {
|
||||
if self.deactivate_for(range, cx) {
|
||||
self.activate_for_step(range.clone(), cx);
|
||||
}
|
||||
}
|
||||
fn crease_content(
|
||||
context: &Model<Context>,
|
||||
range: StepRange,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<Arc<str>> {
|
||||
use std::fmt::Write;
|
||||
let step = context.read(cx).workflow_step_for_range(range)?;
|
||||
let mut output = String::from("\n\n");
|
||||
match &step.status {
|
||||
crate::WorkflowStepStatus::Resolved(ResolvedWorkflowStep { title, suggestions }) => {
|
||||
writeln!(output, "Resolution:").ok()?;
|
||||
writeln!(output, " {title:?}").ok()?;
|
||||
for (buffer, suggestion_groups) in suggestions {
|
||||
let buffer = buffer.read(cx);
|
||||
let buffer_path = buffer
|
||||
.file()
|
||||
.and_then(|file| file.path().to_str())
|
||||
.unwrap_or("untitled");
|
||||
let snapshot = buffer.text_snapshot();
|
||||
writeln!(output, " {buffer_path}:").ok()?;
|
||||
for group in suggestion_groups {
|
||||
for suggestion in &group.suggestions {
|
||||
pretty_print_workflow_suggestion(&mut output, suggestion, &snapshot);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
crate::WorkflowStepStatus::Pending(_) => {
|
||||
writeln!(output, "Resolution: Pending").ok()?;
|
||||
}
|
||||
crate::WorkflowStepStatus::Error(error) => {
|
||||
writeln!(output, "Resolution: Error").ok()?;
|
||||
writeln!(output, "{error:?}").ok()?;
|
||||
}
|
||||
}
|
||||
|
||||
Some(output.into())
|
||||
}
|
||||
pub(crate) fn activate_for_step(&mut self, range: StepRange, cx: &mut WindowContext<'_>) {
|
||||
let text = Self::crease_content(&self.context, range.clone(), cx)
|
||||
.unwrap_or_else(|| Arc::from("Error fetching debug info"));
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
let buffer = editor.buffer().read(cx).as_singleton()?;
|
||||
let snapshot = buffer.read(cx).text_snapshot();
|
||||
let start_offset = range.end.to_offset(&snapshot) + 1;
|
||||
let start_offset = snapshot.clip_offset(start_offset, Bias::Right);
|
||||
let text_len = text.len();
|
||||
buffer.update(cx, |this, cx| {
|
||||
this.edit([(start_offset..start_offset, text)], None, cx);
|
||||
});
|
||||
|
||||
let end_offset = start_offset + text_len;
|
||||
let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let anchor_before = multibuffer_snapshot.anchor_after(start_offset);
|
||||
let anchor_after = multibuffer_snapshot.anchor_before(end_offset);
|
||||
|
||||
let block_id = editor
|
||||
.insert_blocks(
|
||||
[BlockProperties {
|
||||
position: anchor_after,
|
||||
height: 0,
|
||||
style: BlockStyle::Sticky,
|
||||
render: Box::new(move |cx| {
|
||||
div()
|
||||
.w_full()
|
||||
.px(cx.gutter_dimensions.full_width())
|
||||
.child(h_flex().h(px(1.)).bg(Color::Warning.color(cx)))
|
||||
.into_any()
|
||||
}),
|
||||
disposition: BlockDisposition::Below,
|
||||
priority: 0,
|
||||
}],
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.into_iter()
|
||||
.next()?;
|
||||
let info = DebugInfo {
|
||||
range: anchor_before..anchor_after,
|
||||
block_id,
|
||||
};
|
||||
self.active_debug_views.insert(range, info);
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
|
||||
fn deactivate_impl(editor: &mut Editor, debug_data: DebugInfo, cx: &mut ViewContext<Editor>) {
|
||||
editor.remove_blocks(HashSet::from_iter([debug_data.block_id]), None, cx);
|
||||
editor.edit([(debug_data.range, Arc::<str>::default())], cx)
|
||||
}
|
||||
pub(crate) fn deactivate_for(&mut self, range: &StepRange, cx: &mut WindowContext<'_>) -> bool {
|
||||
if let Some(debug_data) = self.active_debug_views.remove(range) {
|
||||
self.editor.update(cx, |this, cx| {
|
||||
Self::deactivate_impl(this, debug_data, cx);
|
||||
});
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn deactivate(&mut self, cx: &mut WindowContext<'_>) {
|
||||
let steps_to_disable = std::mem::take(&mut self.active_debug_views);
|
||||
|
||||
self.editor.update(cx, move |editor, cx| {
|
||||
for (_, debug_data) in steps_to_disable {
|
||||
Self::deactivate_impl(editor, debug_data, cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
fn pretty_print_anchor(
|
||||
out: &mut String,
|
||||
anchor: &language::Anchor,
|
||||
snapshot: &text::BufferSnapshot,
|
||||
) {
|
||||
use std::fmt::Write;
|
||||
let point = anchor.to_point(snapshot);
|
||||
write!(out, "{}:{}", point.row, point.column).ok();
|
||||
}
|
||||
fn pretty_print_range(
|
||||
out: &mut String,
|
||||
range: &Range<language::Anchor>,
|
||||
snapshot: &text::BufferSnapshot,
|
||||
) {
|
||||
use std::fmt::Write;
|
||||
write!(out, " Range: ").ok();
|
||||
pretty_print_anchor(out, &range.start, snapshot);
|
||||
write!(out, "..").ok();
|
||||
pretty_print_anchor(out, &range.end, snapshot);
|
||||
}
|
||||
|
||||
fn pretty_print_workflow_suggestion(
|
||||
out: &mut String,
|
||||
suggestion: &WorkflowSuggestion,
|
||||
snapshot: &text::BufferSnapshot,
|
||||
) {
|
||||
use std::fmt::Write;
|
||||
let (range, description, position) = match suggestion {
|
||||
WorkflowSuggestion::Update { range, description } => (Some(range), Some(description), None),
|
||||
WorkflowSuggestion::CreateFile { description } => (None, Some(description), None),
|
||||
WorkflowSuggestion::AppendChild {
|
||||
position,
|
||||
description,
|
||||
}
|
||||
| WorkflowSuggestion::InsertSiblingBefore {
|
||||
position,
|
||||
description,
|
||||
}
|
||||
| WorkflowSuggestion::InsertSiblingAfter {
|
||||
position,
|
||||
description,
|
||||
}
|
||||
| WorkflowSuggestion::PrependChild {
|
||||
position,
|
||||
description,
|
||||
} => (None, Some(description), Some(position)),
|
||||
|
||||
WorkflowSuggestion::Delete { range } => (Some(range), None, None),
|
||||
};
|
||||
if let Some(description) = description {
|
||||
writeln!(out, " Description: {description}").ok();
|
||||
}
|
||||
if let Some(range) = range {
|
||||
pretty_print_range(out, range, snapshot);
|
||||
}
|
||||
if let Some(position) = position {
|
||||
write!(out, " Position: ").ok();
|
||||
pretty_print_anchor(out, position, snapshot);
|
||||
write!(out, "\n").ok();
|
||||
}
|
||||
write!(out, "\n").ok();
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
use feature_flags::ZedPro;
|
||||
use gpui::DismissEvent;
|
||||
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
|
||||
use proto::Plan;
|
||||
|
||||
@@ -133,8 +132,6 @@ impl PickerDelegate for ModelPickerDelegate {
|
||||
model.is_selected = model.model.id() == selected_model_id
|
||||
&& model.model.provider_id() == selected_provider_id;
|
||||
}
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +155,7 @@ impl PickerDelegate for ModelPickerDelegate {
|
||||
div().pr_1().child(
|
||||
Icon::new(model_info.provider_icon)
|
||||
.color(Color::Muted)
|
||||
.size(IconSize::Medium),
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
|
||||
@@ -775,7 +775,7 @@ impl PromptLibrary {
|
||||
LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![body.to_string().into()],
|
||||
content: body.to_string(),
|
||||
}],
|
||||
stop: Vec::new(),
|
||||
temperature: 1.,
|
||||
|
||||
@@ -12,13 +12,11 @@ use util::ResultExt;
|
||||
pub struct ContentPromptContext {
|
||||
pub content_type: String,
|
||||
pub language_name: Option<String>,
|
||||
pub is_insert: bool,
|
||||
pub is_truncated: bool,
|
||||
pub document_content: String,
|
||||
pub user_prompt: String,
|
||||
pub rewrite_section: String,
|
||||
pub rewrite_section_with_selections: String,
|
||||
pub has_insertion: bool,
|
||||
pub has_replacement: bool,
|
||||
pub rewrite_section: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -35,54 +33,41 @@ pub struct PromptBuilder {
|
||||
handlebars: Arc<Mutex<Handlebars<'static>>>,
|
||||
}
|
||||
|
||||
pub struct PromptOverrideContext<'a> {
|
||||
pub dev_mode: bool,
|
||||
pub fs: Arc<dyn Fs>,
|
||||
pub cx: &'a mut gpui::AppContext,
|
||||
}
|
||||
|
||||
impl PromptBuilder {
|
||||
pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
|
||||
pub fn new(
|
||||
fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
|
||||
) -> Result<Self, Box<TemplateError>> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
Self::register_templates(&mut handlebars)?;
|
||||
|
||||
let handlebars = Arc::new(Mutex::new(handlebars));
|
||||
|
||||
if let Some(override_cx) = override_cx {
|
||||
Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
|
||||
if let Some((fs, cx)) = fs_and_cx {
|
||||
Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
|
||||
}
|
||||
|
||||
Ok(Self { handlebars })
|
||||
}
|
||||
|
||||
fn watch_fs_for_template_overrides(
|
||||
PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &gpui::AppContext,
|
||||
handlebars: Arc<Mutex<Handlebars<'static>>>,
|
||||
) {
|
||||
let templates_dir = paths::prompt_templates_dir();
|
||||
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
let templates_dir = if dev_mode {
|
||||
std::env::current_dir()
|
||||
.ok()
|
||||
.and_then(|pwd| {
|
||||
let pwd_assets_prompts = pwd.join("assets").join("prompts");
|
||||
pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
|
||||
})
|
||||
.unwrap_or_else(|| paths::prompt_overrides_dir().clone())
|
||||
} else {
|
||||
paths::prompt_overrides_dir().clone()
|
||||
};
|
||||
|
||||
// Create the prompt templates directory if it doesn't exist
|
||||
if !fs.is_dir(&templates_dir).await {
|
||||
if let Err(e) = fs.create_dir(&templates_dir).await {
|
||||
if !fs.is_dir(templates_dir).await {
|
||||
if let Err(e) = fs.create_dir(templates_dir).await {
|
||||
log::error!("Failed to create prompt templates directory: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Initial scan of the prompts directory
|
||||
if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
|
||||
if let Ok(mut entries) = fs.read_dir(templates_dir).await {
|
||||
while let Some(Ok(file_path)) = entries.next().await {
|
||||
if file_path.to_string_lossy().ends_with(".hbs") {
|
||||
if let Ok(content) = fs.load(&file_path).await {
|
||||
@@ -110,7 +95,7 @@ impl PromptBuilder {
|
||||
}
|
||||
|
||||
// Watch for changes
|
||||
let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
|
||||
let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
|
||||
while let Some(changed_paths) = changes.next().await {
|
||||
for changed_path in changed_paths {
|
||||
if changed_path.extension().map_or(false, |ext| ext == "hbs") {
|
||||
@@ -162,8 +147,7 @@ impl PromptBuilder {
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
transform_range: Range<usize>,
|
||||
selected_ranges: Vec<Range<usize>>,
|
||||
range: Range<usize>,
|
||||
) -> Result<String, RenderError> {
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => "text",
|
||||
@@ -171,20 +155,21 @@ impl PromptBuilder {
|
||||
};
|
||||
|
||||
const MAX_CTX: usize = 50000;
|
||||
let is_insert = range.is_empty();
|
||||
let mut is_truncated = false;
|
||||
|
||||
let before_range = 0..transform_range.start;
|
||||
let before_range = 0..range.start;
|
||||
let truncated_before = if before_range.len() > MAX_CTX {
|
||||
is_truncated = true;
|
||||
transform_range.start - MAX_CTX..transform_range.start
|
||||
range.start - MAX_CTX..range.start
|
||||
} else {
|
||||
before_range
|
||||
};
|
||||
|
||||
let after_range = transform_range.end..buffer.len();
|
||||
let after_range = range.end..buffer.len();
|
||||
let truncated_after = if after_range.len() > MAX_CTX {
|
||||
is_truncated = true;
|
||||
transform_range.end..transform_range.end + MAX_CTX
|
||||
range.end..range.end + MAX_CTX
|
||||
} else {
|
||||
after_range
|
||||
};
|
||||
@@ -193,61 +178,37 @@ impl PromptBuilder {
|
||||
for chunk in buffer.text_for_range(truncated_before) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
document_content.push_str("<rewrite_this>\n");
|
||||
for chunk in buffer.text_for_range(transform_range.clone()) {
|
||||
document_content.push_str(chunk);
|
||||
if is_insert {
|
||||
document_content.push_str("<insert_here></insert_here>");
|
||||
} else {
|
||||
document_content.push_str("<rewrite_this>\n");
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
document_content.push_str("\n</rewrite_this>");
|
||||
}
|
||||
document_content.push_str("\n</rewrite_this>");
|
||||
|
||||
for chunk in buffer.text_for_range(truncated_after) {
|
||||
document_content.push_str(chunk);
|
||||
}
|
||||
|
||||
let mut rewrite_section = String::new();
|
||||
for chunk in buffer.text_for_range(transform_range.clone()) {
|
||||
rewrite_section.push_str(chunk);
|
||||
}
|
||||
|
||||
let rewrite_section_with_selections = {
|
||||
let mut section_with_selections = String::new();
|
||||
let mut last_end = 0;
|
||||
for selected_range in &selected_ranges {
|
||||
if selected_range.start > last_end {
|
||||
section_with_selections.push_str(
|
||||
&rewrite_section[last_end..selected_range.start - transform_range.start],
|
||||
);
|
||||
}
|
||||
if selected_range.start == selected_range.end {
|
||||
section_with_selections.push_str("<insert_here></insert_here>");
|
||||
} else {
|
||||
section_with_selections.push_str("<edit_here>");
|
||||
section_with_selections.push_str(
|
||||
&rewrite_section[selected_range.start - transform_range.start
|
||||
..selected_range.end - transform_range.start],
|
||||
);
|
||||
section_with_selections.push_str("</edit_here>");
|
||||
}
|
||||
last_end = selected_range.end - transform_range.start;
|
||||
let rewrite_section = if !is_insert {
|
||||
let mut section = String::new();
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
section.push_str(chunk);
|
||||
}
|
||||
if last_end < rewrite_section.len() {
|
||||
section_with_selections.push_str(&rewrite_section[last_end..]);
|
||||
}
|
||||
section_with_selections
|
||||
Some(section)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
|
||||
let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
|
||||
|
||||
let context = ContentPromptContext {
|
||||
content_type: content_type.to_string(),
|
||||
language_name: language_name.map(|s| s.to_string()),
|
||||
is_insert,
|
||||
is_truncated,
|
||||
document_content,
|
||||
user_prompt,
|
||||
rewrite_section,
|
||||
rewrite_section_with_selections,
|
||||
has_insertion,
|
||||
has_replacement,
|
||||
};
|
||||
|
||||
self.handlebars.lock().render("content_prompt", &context)
|
||||
|
||||
@@ -6,7 +6,6 @@ use fuzzy::{match_strings, StringMatchCandidate};
|
||||
use gpui::{AppContext, Model, Task, ViewContext, WeakView, WindowContext};
|
||||
use language::{Anchor, Buffer, CodeLabel, Documentation, HighlightId, LanguageServerId, ToPoint};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use project::CompletionIntent;
|
||||
use rope::Point;
|
||||
use std::{
|
||||
ops::Range,
|
||||
@@ -18,6 +17,7 @@ use std::{
|
||||
use ui::ActiveTheme;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub mod active_command;
|
||||
pub mod default_command;
|
||||
pub mod diagnostics_command;
|
||||
pub mod docs_command;
|
||||
@@ -29,7 +29,7 @@ pub mod prompt_command;
|
||||
pub mod search_command;
|
||||
pub mod symbols_command;
|
||||
pub mod tabs_command;
|
||||
pub mod terminal_command;
|
||||
pub mod term_command;
|
||||
pub mod workflow_command;
|
||||
|
||||
pub(crate) struct SlashCommandCompletionProvider {
|
||||
@@ -107,24 +107,20 @@ impl SlashCommandCompletionProvider {
|
||||
let command_range = command_range.clone();
|
||||
let editor = editor.clone();
|
||||
let workspace = workspace.clone();
|
||||
Arc::new(
|
||||
move |intent: CompletionIntent, cx: &mut WindowContext| {
|
||||
if intent.is_complete() {
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.run_command(
|
||||
command_range.clone(),
|
||||
&command_name,
|
||||
None,
|
||||
true,
|
||||
workspace.clone(),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
},
|
||||
) as Arc<_>
|
||||
Arc::new(move |cx: &mut WindowContext| {
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.run_command(
|
||||
command_range.clone(),
|
||||
&command_name,
|
||||
None,
|
||||
true,
|
||||
workspace.clone(),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}) as Arc<_>
|
||||
})
|
||||
},
|
||||
);
|
||||
@@ -156,6 +152,7 @@ impl SlashCommandCompletionProvider {
|
||||
let mut flag = self.cancel_flag.lock();
|
||||
flag.store(true, SeqCst);
|
||||
*flag = new_cancel_flag.clone();
|
||||
|
||||
let commands = SlashCommandRegistry::global(cx);
|
||||
if let Some(command) = commands.command(command_name) {
|
||||
let completions = command.complete_argument(
|
||||
@@ -181,21 +178,19 @@ impl SlashCommandCompletionProvider {
|
||||
let command_range = command_range.clone();
|
||||
let command_name = command_name.clone();
|
||||
let command_argument = command_argument.new_text.clone();
|
||||
move |intent: CompletionIntent, cx: &mut WindowContext| {
|
||||
if intent.is_complete() {
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.run_command(
|
||||
command_range.clone(),
|
||||
&command_name,
|
||||
Some(&command_argument),
|
||||
true,
|
||||
workspace.clone(),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
move |cx: &mut WindowContext| {
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.run_command(
|
||||
command_range.clone(),
|
||||
&command_name,
|
||||
Some(&command_argument),
|
||||
true,
|
||||
workspace.clone(),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}) as Arc<_>
|
||||
})
|
||||
@@ -210,7 +205,7 @@ impl SlashCommandCompletionProvider {
|
||||
|
||||
project::Completion {
|
||||
old_range: argument_range.clone(),
|
||||
label: command_argument.label,
|
||||
label: CodeLabel::plain(command_argument.label, None),
|
||||
new_text,
|
||||
documentation: None,
|
||||
server_id: LanguageServerId(0),
|
||||
|
||||
102
crates/assistant/src/slash_command/active_command.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use super::{
|
||||
diagnostics_command::write_single_file_diagnostics,
|
||||
file_command::{build_entry_output_section, codeblock_fence_for_path},
|
||||
SlashCommand, SlashCommandOutput,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::ArgumentCompletion;
|
||||
use editor::Editor;
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use ui::WindowContext;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub(crate) struct ActiveSlashCommand;
|
||||
|
||||
impl SlashCommand for ActiveSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
"active".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"insert active tab".into()
|
||||
}
|
||||
|
||||
fn menu_text(&self) -> String {
|
||||
"Insert Active Tab".into()
|
||||
}
|
||||
|
||||
fn complete_argument(
|
||||
self: Arc<Self>,
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Err(anyhow!("this command does not require argument")))
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_argument: Option<&str>,
|
||||
workspace: WeakView<Workspace>,
|
||||
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<SlashCommandOutput>> {
|
||||
let output = workspace.update(cx, |workspace, cx| {
|
||||
let Some(active_item) = workspace.active_item(cx) else {
|
||||
return Task::ready(Err(anyhow!("no active tab")));
|
||||
};
|
||||
let Some(buffer) = active_item
|
||||
.downcast::<Editor>()
|
||||
.and_then(|editor| editor.read(cx).buffer().read(cx).as_singleton())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("active tab is not an editor")));
|
||||
};
|
||||
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let path = snapshot.resolve_file_path(cx, true);
|
||||
let task = cx.background_executor().spawn({
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let mut output = String::new();
|
||||
output.push_str(&codeblock_fence_for_path(path.as_deref(), None));
|
||||
for chunk in snapshot.as_rope().chunks() {
|
||||
output.push_str(chunk);
|
||||
}
|
||||
if !output.ends_with('\n') {
|
||||
output.push('\n');
|
||||
}
|
||||
output.push_str("```\n");
|
||||
let has_diagnostics =
|
||||
write_single_file_diagnostics(&mut output, path.as_deref(), &snapshot);
|
||||
if output.ends_with('\n') {
|
||||
output.pop();
|
||||
}
|
||||
(output, has_diagnostics)
|
||||
}
|
||||
});
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let (text, has_diagnostics) = task.await;
|
||||
let range = 0..text.len();
|
||||
Ok(SlashCommandOutput {
|
||||
text,
|
||||
sections: vec![build_entry_output_section(
|
||||
range,
|
||||
path.as_deref(),
|
||||
false,
|
||||
None,
|
||||
)],
|
||||
run_commands_in_text: has_diagnostics,
|
||||
})
|
||||
})
|
||||
});
|
||||
output.unwrap_or_else(|error| Task::ready(Err(error)))
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use super::{SlashCommand, SlashCommandOutput};
|
||||
use crate::prompt_library::PromptStore;
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
|
||||
use gpui::{Task, WeakView};
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use std::{
|
||||
fmt::Write,
|
||||
@@ -35,7 +35,7 @@ impl SlashCommand for DefaultSlashCommand {
|
||||
_query: String,
|
||||
_cancellation_flag: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Err(anyhow!("this command does not require argument")))
|
||||
}
|
||||
|
||||
@@ -43,7 +43,6 @@ impl DiagnosticsSlashCommand {
|
||||
worktree_id: entry.worktree_id.to_usize(),
|
||||
path: entry.path.clone(),
|
||||
path_prefix: path_prefix.clone(),
|
||||
is_dir: false, // Diagnostics can't be produced for directories
|
||||
distance_to_relative_ancestor: 0,
|
||||
})
|
||||
.collect(),
|
||||
@@ -108,7 +107,7 @@ impl SlashCommand for DiagnosticsSlashCommand {
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
let Some(workspace) = workspace.and_then(|workspace| workspace.upgrade()) else {
|
||||
return Task::ready(Err(anyhow!("workspace was dropped")));
|
||||
@@ -147,7 +146,7 @@ impl SlashCommand for DiagnosticsSlashCommand {
|
||||
Ok(matches
|
||||
.into_iter()
|
||||
.map(|completion| ArgumentCompletion {
|
||||
label: completion.clone().into(),
|
||||
label: completion.clone(),
|
||||
new_text: completion,
|
||||
run_command: true,
|
||||
})
|
||||
@@ -169,66 +168,58 @@ impl SlashCommand for DiagnosticsSlashCommand {
|
||||
let options = Options::parse(argument);
|
||||
|
||||
let task = collect_diagnostics(workspace.read(cx).project().clone(), options, cx);
|
||||
|
||||
cx.spawn(move |_| async move {
|
||||
let Some((text, sections)) = task.await? else {
|
||||
return Ok(SlashCommandOutput {
|
||||
sections: vec![SlashCommandOutputSection {
|
||||
range: 0..1,
|
||||
icon: IconName::Library,
|
||||
label: "No Diagnostics".into(),
|
||||
}],
|
||||
text: "\n".to_string(),
|
||||
run_commands_in_text: true,
|
||||
});
|
||||
return Ok(SlashCommandOutput::default());
|
||||
};
|
||||
|
||||
let sections = sections
|
||||
.into_iter()
|
||||
.map(|(range, placeholder_type)| SlashCommandOutputSection {
|
||||
range,
|
||||
icon: match placeholder_type {
|
||||
PlaceholderType::Root(_, _) => IconName::ExclamationTriangle,
|
||||
PlaceholderType::File(_) => IconName::File,
|
||||
PlaceholderType::Diagnostic(DiagnosticType::Error, _) => IconName::XCircle,
|
||||
PlaceholderType::Diagnostic(DiagnosticType::Warning, _) => {
|
||||
IconName::ExclamationTriangle
|
||||
}
|
||||
},
|
||||
label: match placeholder_type {
|
||||
PlaceholderType::Root(summary, source) => {
|
||||
let mut label = String::new();
|
||||
label.push_str("Diagnostics");
|
||||
if let Some(source) = source {
|
||||
write!(label, " ({})", source).unwrap();
|
||||
}
|
||||
|
||||
if summary.error_count > 0 || summary.warning_count > 0 {
|
||||
label.push(':');
|
||||
|
||||
if summary.error_count > 0 {
|
||||
write!(label, " {} errors", summary.error_count).unwrap();
|
||||
if summary.warning_count > 0 {
|
||||
label.push_str(",");
|
||||
}
|
||||
}
|
||||
|
||||
if summary.warning_count > 0 {
|
||||
write!(label, " {} warnings", summary.warning_count).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
label.into()
|
||||
}
|
||||
PlaceholderType::File(file_path) => file_path.into(),
|
||||
PlaceholderType::Diagnostic(_, message) => message.into(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(SlashCommandOutput {
|
||||
text,
|
||||
sections,
|
||||
sections: sections
|
||||
.into_iter()
|
||||
.map(|(range, placeholder_type)| SlashCommandOutputSection {
|
||||
range,
|
||||
icon: match placeholder_type {
|
||||
PlaceholderType::Root(_, _) => IconName::ExclamationTriangle,
|
||||
PlaceholderType::File(_) => IconName::File,
|
||||
PlaceholderType::Diagnostic(DiagnosticType::Error, _) => {
|
||||
IconName::XCircle
|
||||
}
|
||||
PlaceholderType::Diagnostic(DiagnosticType::Warning, _) => {
|
||||
IconName::ExclamationTriangle
|
||||
}
|
||||
},
|
||||
label: match placeholder_type {
|
||||
PlaceholderType::Root(summary, source) => {
|
||||
let mut label = String::new();
|
||||
label.push_str("Diagnostics");
|
||||
if let Some(source) = source {
|
||||
write!(label, " ({})", source).unwrap();
|
||||
}
|
||||
|
||||
if summary.error_count > 0 || summary.warning_count > 0 {
|
||||
label.push(':');
|
||||
|
||||
if summary.error_count > 0 {
|
||||
write!(label, " {} errors", summary.error_count).unwrap();
|
||||
if summary.warning_count > 0 {
|
||||
label.push_str(",");
|
||||
}
|
||||
}
|
||||
|
||||
if summary.warning_count > 0 {
|
||||
write!(label, " {} warnings", summary.warning_count)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
label.into()
|
||||
}
|
||||
PlaceholderType::File(file_path) => file_path.into(),
|
||||
PlaceholderType::Diagnostic(_, message) => message.into(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
run_commands_in_text: false,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -164,7 +164,7 @@ impl SlashCommand for DocsSlashCommand {
|
||||
query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
self.ensure_rust_doc_providers_are_registered(workspace, cx);
|
||||
|
||||
@@ -182,7 +182,7 @@ impl SlashCommand for DocsSlashCommand {
|
||||
items
|
||||
.into_iter()
|
||||
.map(|item| ArgumentCompletion {
|
||||
label: item.clone().into(),
|
||||
label: item.clone(),
|
||||
new_text: format!("{provider} {item}"),
|
||||
run_command: true,
|
||||
})
|
||||
@@ -194,7 +194,7 @@ impl SlashCommand for DocsSlashCommand {
|
||||
let providers = indexed_docs_registry.list_providers();
|
||||
if providers.is_empty() {
|
||||
return Ok(vec![ArgumentCompletion {
|
||||
label: "No available docs providers.".into(),
|
||||
label: "No available docs providers.".to_string(),
|
||||
new_text: String::new(),
|
||||
run_command: false,
|
||||
}]);
|
||||
@@ -203,7 +203,7 @@ impl SlashCommand for DocsSlashCommand {
|
||||
Ok(providers
|
||||
.into_iter()
|
||||
.map(|provider| ArgumentCompletion {
|
||||
label: provider.to_string().into(),
|
||||
label: provider.to_string(),
|
||||
new_text: provider.to_string(),
|
||||
run_command: false,
|
||||
})
|
||||
@@ -222,38 +222,8 @@ impl SlashCommand for DocsSlashCommand {
|
||||
drop(store.clone().index(package.as_str().into()));
|
||||
}
|
||||
|
||||
let suggested_packages = store.clone().suggest_packages().await?;
|
||||
let search_results = store.search(package).await;
|
||||
|
||||
let mut items = build_completions(provider.clone(), search_results);
|
||||
let workspace_crate_completions = suggested_packages
|
||||
.into_iter()
|
||||
.filter(|package_name| {
|
||||
!items
|
||||
.iter()
|
||||
.any(|item| item.label.text() == package_name.as_ref())
|
||||
})
|
||||
.map(|package_name| ArgumentCompletion {
|
||||
label: format!("{package_name} (unindexed)").into(),
|
||||
new_text: format!("{provider} {package_name}"),
|
||||
run_command: true,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
items.extend(workspace_crate_completions);
|
||||
|
||||
if items.is_empty() {
|
||||
return Ok(vec![ArgumentCompletion {
|
||||
label: format!(
|
||||
"Enter a {package_term} name.",
|
||||
package_term = package_term(&provider)
|
||||
)
|
||||
.into(),
|
||||
new_text: provider.to_string(),
|
||||
run_command: false,
|
||||
}]);
|
||||
}
|
||||
|
||||
Ok(items)
|
||||
let items = store.search(package).await;
|
||||
Ok(build_completions(provider, items))
|
||||
}
|
||||
DocsSlashCommandArgs::SearchItemDocs {
|
||||
provider,
|
||||
@@ -299,13 +269,6 @@ impl SlashCommand for DocsSlashCommand {
|
||||
} => (provider, item_path),
|
||||
};
|
||||
|
||||
if key.trim().is_empty() {
|
||||
bail!(
|
||||
"no {package_term} name provided",
|
||||
package_term = package_term(&provider)
|
||||
);
|
||||
}
|
||||
|
||||
let store = store?;
|
||||
|
||||
if let Some(package) = args.package() {
|
||||
@@ -428,15 +391,6 @@ impl DocsSlashCommandArgs {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the term used to refer to a package.
|
||||
fn package_term(provider: &ProviderId) -> &'static str {
|
||||
if provider == &DocsDotRsProvider::id() || provider == &LocalRustdocProvider::id() {
|
||||
return "crate";
|
||||
}
|
||||
|
||||
"package"
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -8,7 +8,7 @@ use assistant_slash_command::{
|
||||
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
|
||||
};
|
||||
use futures::AsyncReadExt;
|
||||
use gpui::{Task, WeakView};
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use language::LspAdapterDelegate;
|
||||
@@ -120,7 +120,7 @@ impl SlashCommand for FetchSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(Vec::new()))
|
||||
}
|
||||
@@ -150,10 +150,6 @@ impl SlashCommand for FetchSlashCommand {
|
||||
let url = SharedString::from(url);
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let text = text.await?;
|
||||
if text.trim().is_empty() {
|
||||
bail!("no textual content found");
|
||||
}
|
||||
|
||||
let range = 0..text.len();
|
||||
Ok(SlashCommandOutput {
|
||||
text,
|
||||
|
||||
@@ -3,7 +3,7 @@ use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
|
||||
use fuzzy::PathMatch;
|
||||
use gpui::{AppContext, Model, Task, View, WeakView};
|
||||
use language::{BufferSnapshot, CodeLabel, HighlightId, LineEnding, LspAdapterDelegate};
|
||||
use language::{BufferSnapshot, LineEnding, LspAdapterDelegate};
|
||||
use project::{PathMatchCandidateSet, Project};
|
||||
use std::{
|
||||
fmt::Write,
|
||||
@@ -29,30 +29,11 @@ impl FileSlashCommand {
|
||||
let workspace = workspace.read(cx);
|
||||
let project = workspace.project().read(cx);
|
||||
let entries = workspace.recent_navigation_history(Some(10), cx);
|
||||
|
||||
let entries = entries
|
||||
.into_iter()
|
||||
.map(|entries| (entries.0, false))
|
||||
.chain(project.worktrees(cx).flat_map(|worktree| {
|
||||
let worktree = worktree.read(cx);
|
||||
let id = worktree.id();
|
||||
worktree.child_entries(Path::new("")).map(move |entry| {
|
||||
(
|
||||
project::ProjectPath {
|
||||
worktree_id: id,
|
||||
path: entry.path.clone(),
|
||||
},
|
||||
entry.kind.is_dir(),
|
||||
)
|
||||
})
|
||||
}))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let path_prefix: Arc<str> = Arc::default();
|
||||
Task::ready(
|
||||
entries
|
||||
.into_iter()
|
||||
.filter_map(|(entry, is_dir)| {
|
||||
.filter_map(|(entry, _)| {
|
||||
let worktree = project.worktree_for_id(entry.worktree_id, cx)?;
|
||||
let mut full_path = PathBuf::from(worktree.read(cx).root_name());
|
||||
full_path.push(&entry.path);
|
||||
@@ -63,7 +44,6 @@ impl FileSlashCommand {
|
||||
path: full_path.into(),
|
||||
path_prefix: path_prefix.clone(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
is_dir,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
@@ -74,7 +54,6 @@ impl FileSlashCommand {
|
||||
.into_iter()
|
||||
.map(|worktree| {
|
||||
let worktree = worktree.read(cx);
|
||||
|
||||
PathMatchCandidateSet {
|
||||
snapshot: worktree.snapshot(),
|
||||
include_ignored: worktree
|
||||
@@ -125,42 +104,29 @@ impl SlashCommand for FileSlashCommand {
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
let Some(workspace) = workspace.and_then(|workspace| workspace.upgrade()) else {
|
||||
return Task::ready(Err(anyhow!("workspace was dropped")));
|
||||
};
|
||||
|
||||
let paths = self.search_paths(query, cancellation_flag, &workspace, cx);
|
||||
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
|
||||
cx.background_executor().spawn(async move {
|
||||
Ok(paths
|
||||
.await
|
||||
.into_iter()
|
||||
.filter_map(|path_match| {
|
||||
.map(|path_match| {
|
||||
let text = format!(
|
||||
"{}{}",
|
||||
path_match.path_prefix,
|
||||
path_match.path.to_string_lossy()
|
||||
);
|
||||
|
||||
let mut label = CodeLabel::default();
|
||||
let file_name = path_match.path.file_name()?.to_string_lossy();
|
||||
let label_text = if path_match.is_dir {
|
||||
format!("{}/ ", file_name)
|
||||
} else {
|
||||
format!("{} ", file_name)
|
||||
};
|
||||
|
||||
label.push_str(label_text.as_str(), None);
|
||||
label.push_str(&text, comment_id);
|
||||
label.filter_range = 0..file_name.len();
|
||||
|
||||
Some(ArgumentCompletion {
|
||||
label,
|
||||
ArgumentCompletion {
|
||||
label: text.clone(),
|
||||
new_text: text,
|
||||
run_command: true,
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
|
||||
@@ -6,7 +6,7 @@ use assistant_slash_command::{
|
||||
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
|
||||
};
|
||||
use chrono::Local;
|
||||
use gpui::{Task, WeakView};
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use ui::prelude::*;
|
||||
use workspace::Workspace;
|
||||
@@ -35,7 +35,7 @@ impl SlashCommand for NowSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(Vec::new()))
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ impl SlashCommand for ProjectSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Err(anyhow!("this command does not require argument")))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::{SlashCommand, SlashCommandOutput};
|
||||
use crate::prompt_library::PromptStore;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
|
||||
use gpui::{Task, WeakView};
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
use ui::prelude::*;
|
||||
@@ -32,7 +32,7 @@ impl SlashCommand for PromptSlashCommand {
|
||||
query: String,
|
||||
_cancellation_flag: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
let store = PromptStore::global(cx);
|
||||
cx.background_executor().spawn(async move {
|
||||
@@ -42,7 +42,7 @@ impl SlashCommand for PromptSlashCommand {
|
||||
.filter_map(|prompt| {
|
||||
let prompt_title = prompt.title?.to_string();
|
||||
Some(ArgumentCompletion {
|
||||
label: prompt_title.clone().into(),
|
||||
label: prompt_title.clone(),
|
||||
new_text: prompt_title,
|
||||
run_command: true,
|
||||
})
|
||||
@@ -77,11 +77,6 @@ impl SlashCommand for PromptSlashCommand {
|
||||
});
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let mut prompt = prompt.await?;
|
||||
|
||||
if prompt.starts_with('/') {
|
||||
// Prevent an edge case where the inserted prompt starts with a slash command (that leads to funky rendering).
|
||||
prompt.insert(0, '\n');
|
||||
}
|
||||
if prompt.is_empty() {
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ use super::{
|
||||
};
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
|
||||
use feature_flags::FeatureFlag;
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::{CodeLabel, LineEnding, LspAdapterDelegate};
|
||||
use semantic_index::SemanticIndex;
|
||||
@@ -18,12 +17,6 @@ use ui::{prelude::*, IconName};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub(crate) struct SearchSlashCommandFeatureFlag;
|
||||
|
||||
impl FeatureFlag for SearchSlashCommandFeatureFlag {
|
||||
const NAME: &'static str = "search-slash-command";
|
||||
}
|
||||
|
||||
pub(crate) struct SearchSlashCommand;
|
||||
|
||||
impl SlashCommand for SearchSlashCommand {
|
||||
@@ -52,7 +45,7 @@ impl SlashCommand for SearchSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(Vec::new()))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::{SlashCommand, SlashCommandOutput};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
|
||||
use editor::Editor;
|
||||
use gpui::{Task, WeakView};
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use std::sync::Arc;
|
||||
use std::{path::Path, sync::atomic::AtomicBool};
|
||||
@@ -29,7 +29,7 @@ impl SlashCommand for OutlineSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Err(anyhow!("this command does not require argument")))
|
||||
}
|
||||
|
||||
@@ -3,35 +3,29 @@ use super::{
|
||||
file_command::{build_entry_output_section, codeblock_fence_for_path},
|
||||
SlashCommand, SlashCommandOutput,
|
||||
};
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_slash_command::ArgumentCompletion;
|
||||
use collections::HashMap;
|
||||
use editor::Editor;
|
||||
use gpui::{Entity, Task, WeakView};
|
||||
use language::{BufferSnapshot, LspAdapterDelegate};
|
||||
use std::{
|
||||
fmt::Write,
|
||||
path::PathBuf,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
};
|
||||
use gpui::{AppContext, Entity, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use std::{fmt::Write, sync::Arc};
|
||||
use ui::WindowContext;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub(crate) struct TabsSlashCommand;
|
||||
|
||||
const ALL_TABS_COMPLETION_ITEM: &str = "all";
|
||||
|
||||
impl SlashCommand for TabsSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
"tabs".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"insert open tabs (active tab by default)".to_owned()
|
||||
"insert open tabs".into()
|
||||
}
|
||||
|
||||
fn menu_text(&self) -> String {
|
||||
"Insert Open Tabs".to_owned()
|
||||
"Insert Open Tabs".into()
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
@@ -40,188 +34,85 @@ impl SlashCommand for TabsSlashCommand {
|
||||
|
||||
fn complete_argument(
|
||||
self: Arc<Self>,
|
||||
query: String,
|
||||
cancel: Arc<AtomicBool>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
_query: String,
|
||||
_cancel: Arc<std::sync::atomic::AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
let all_tabs_completion_item = if ALL_TABS_COMPLETION_ITEM.contains(&query) {
|
||||
Some(ArgumentCompletion {
|
||||
label: ALL_TABS_COMPLETION_ITEM.into(),
|
||||
new_text: ALL_TABS_COMPLETION_ITEM.to_owned(),
|
||||
run_command: true,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tab_items_search = tab_items_for_query(workspace, query, cancel, false, cx);
|
||||
cx.spawn(|_| async move {
|
||||
let tab_completion_items =
|
||||
tab_items_search
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|(path, ..)| {
|
||||
let path_string = path.as_deref()?.to_string_lossy().to_string();
|
||||
Some(ArgumentCompletion {
|
||||
label: path_string.clone().into(),
|
||||
new_text: path_string,
|
||||
run_command: true,
|
||||
})
|
||||
});
|
||||
Ok(all_tabs_completion_item
|
||||
.into_iter()
|
||||
.chain(tab_completion_items)
|
||||
.collect::<Vec<_>>())
|
||||
})
|
||||
Task::ready(Err(anyhow!("this command does not require argument")))
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
argument: Option<&str>,
|
||||
_argument: Option<&str>,
|
||||
workspace: WeakView<Workspace>,
|
||||
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<SlashCommandOutput>> {
|
||||
let tab_items_search = tab_items_for_query(
|
||||
Some(workspace),
|
||||
argument.map(ToOwned::to_owned).unwrap_or_default(),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
true,
|
||||
cx,
|
||||
);
|
||||
let open_buffers = workspace.update(cx, |workspace, cx| {
|
||||
let mut timestamps_by_entity_id = HashMap::default();
|
||||
let mut open_buffers = Vec::new();
|
||||
|
||||
cx.background_executor().spawn(async move {
|
||||
let mut sections = Vec::new();
|
||||
let mut text = String::new();
|
||||
let mut has_diagnostics = false;
|
||||
for (full_path, buffer, _) in tab_items_search.await? {
|
||||
let section_start_ix = text.len();
|
||||
text.push_str(&codeblock_fence_for_path(full_path.as_deref(), None));
|
||||
for chunk in buffer.as_rope().chunks() {
|
||||
text.push_str(chunk);
|
||||
for pane in workspace.panes() {
|
||||
let pane = pane.read(cx);
|
||||
for entry in pane.activation_history() {
|
||||
timestamps_by_entity_id.insert(entry.entity_id, entry.timestamp);
|
||||
}
|
||||
if !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
writeln!(text, "```").unwrap();
|
||||
if write_single_file_diagnostics(&mut text, full_path.as_deref(), &buffer) {
|
||||
has_diagnostics = true;
|
||||
}
|
||||
if !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
|
||||
let section_end_ix = text.len() - 1;
|
||||
sections.push(build_entry_output_section(
|
||||
section_start_ix..section_end_ix,
|
||||
full_path.as_deref(),
|
||||
false,
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(SlashCommandOutput {
|
||||
text,
|
||||
sections,
|
||||
run_commands_in_text: has_diagnostics,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn tab_items_for_query(
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
mut query: String,
|
||||
cancel: Arc<AtomicBool>,
|
||||
use_active_tab_for_empty_query: bool,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<anyhow::Result<Vec<(Option<PathBuf>, BufferSnapshot, usize)>>> {
|
||||
cx.spawn(|mut cx| async move {
|
||||
query.make_ascii_lowercase();
|
||||
let mut open_buffers =
|
||||
workspace
|
||||
.context("no workspace")?
|
||||
.update(&mut cx, |workspace, cx| {
|
||||
if use_active_tab_for_empty_query && query.trim().is_empty() {
|
||||
let active_editor = workspace
|
||||
.active_item(cx)
|
||||
.context("no active item")?
|
||||
.downcast::<Editor>()
|
||||
.context("active item is not an editor")?;
|
||||
let snapshot = active_editor
|
||||
.read(cx)
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.as_singleton()
|
||||
.context("active editor is not a singleton buffer")?
|
||||
.read(cx)
|
||||
.snapshot();
|
||||
for editor in workspace.items_of_type::<Editor>(cx) {
|
||||
if let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton() {
|
||||
if let Some(timestamp) = timestamps_by_entity_id.get(&editor.entity_id()) {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let full_path = snapshot.resolve_file_path(cx, true);
|
||||
return anyhow::Ok(vec![(full_path, snapshot, 0)]);
|
||||
open_buffers.push((full_path, snapshot, *timestamp));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut timestamps_by_entity_id = HashMap::default();
|
||||
let mut open_buffers = Vec::new();
|
||||
open_buffers
|
||||
});
|
||||
|
||||
for pane in workspace.panes() {
|
||||
let pane = pane.read(cx);
|
||||
for entry in pane.activation_history() {
|
||||
timestamps_by_entity_id.insert(entry.entity_id, entry.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
for editor in workspace.items_of_type::<Editor>(cx) {
|
||||
if let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton() {
|
||||
if let Some(timestamp) =
|
||||
timestamps_by_entity_id.get(&editor.entity_id())
|
||||
{
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let full_path = snapshot.resolve_file_path(cx, true);
|
||||
open_buffers.push((full_path, snapshot, *timestamp));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(open_buffers)
|
||||
})??;
|
||||
|
||||
let background_executor = cx.background_executor().clone();
|
||||
cx.background_executor()
|
||||
.spawn(async move {
|
||||
match open_buffers {
|
||||
Ok(mut open_buffers) => cx.background_executor().spawn(async move {
|
||||
open_buffers.sort_by_key(|(_, _, timestamp)| *timestamp);
|
||||
let query = query.trim();
|
||||
if query.is_empty() || query == ALL_TABS_COMPLETION_ITEM {
|
||||
return Ok(open_buffers);
|
||||
|
||||
let mut sections = Vec::new();
|
||||
let mut text = String::new();
|
||||
let mut has_diagnostics = false;
|
||||
for (full_path, buffer, _) in open_buffers {
|
||||
let section_start_ix = text.len();
|
||||
text.push_str(&codeblock_fence_for_path(full_path.as_deref(), None));
|
||||
for chunk in buffer.as_rope().chunks() {
|
||||
text.push_str(chunk);
|
||||
}
|
||||
if !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
writeln!(text, "```").unwrap();
|
||||
if write_single_file_diagnostics(&mut text, full_path.as_deref(), &buffer) {
|
||||
has_diagnostics = true;
|
||||
}
|
||||
if !text.ends_with('\n') {
|
||||
text.push('\n');
|
||||
}
|
||||
|
||||
let section_end_ix = text.len() - 1;
|
||||
sections.push(build_entry_output_section(
|
||||
section_start_ix..section_end_ix,
|
||||
full_path.as_deref(),
|
||||
false,
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
let match_candidates = open_buffers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(id, (full_path, ..))| {
|
||||
let path_string = full_path.as_deref()?.to_string_lossy().to_string();
|
||||
Some(fuzzy::StringMatchCandidate {
|
||||
id,
|
||||
char_bag: path_string.as_str().into(),
|
||||
string: path_string,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let string_matches = fuzzy::match_strings(
|
||||
&match_candidates,
|
||||
&query,
|
||||
true,
|
||||
usize::MAX,
|
||||
&cancel,
|
||||
background_executor,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(string_matches
|
||||
.into_iter()
|
||||
.filter_map(|string_match| open_buffers.get(string_match.candidate_id))
|
||||
.cloned()
|
||||
.collect())
|
||||
})
|
||||
.await
|
||||
})
|
||||
Ok(SlashCommandOutput {
|
||||
text,
|
||||
sections,
|
||||
run_commands_in_text: has_diagnostics,
|
||||
})
|
||||
}),
|
||||
Err(error) => Task::ready(Err(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,17 @@ use crate::DEFAULT_CONTEXT_LINES;
|
||||
|
||||
use super::create_label_for_command;
|
||||
|
||||
pub(crate) struct TerminalSlashCommand;
|
||||
pub(crate) struct TermSlashCommand;
|
||||
|
||||
const LINE_COUNT_ARG: &str = "--line-count";
|
||||
|
||||
impl SlashCommand for TerminalSlashCommand {
|
||||
impl SlashCommand for TermSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
"terminal".into()
|
||||
"term".into()
|
||||
}
|
||||
|
||||
fn label(&self, cx: &AppContext) -> CodeLabel {
|
||||
create_label_for_command("terminal", &[LINE_COUNT_ARG], cx)
|
||||
create_label_for_command("term", &[LINE_COUNT_ARG], cx)
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
@@ -45,10 +45,10 @@ impl SlashCommand for TerminalSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(vec![ArgumentCompletion {
|
||||
label: LINE_COUNT_ARG.into(),
|
||||
label: LINE_COUNT_ARG.to_string(),
|
||||
new_text: LINE_COUNT_ARG.to_string(),
|
||||
run_command: true,
|
||||
}]))
|
||||
@@ -7,7 +7,7 @@ use anyhow::Result;
|
||||
use assistant_slash_command::{
|
||||
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
|
||||
};
|
||||
use gpui::{Task, WeakView};
|
||||
use gpui::{AppContext, Task, WeakView};
|
||||
use language::LspAdapterDelegate;
|
||||
use ui::prelude::*;
|
||||
|
||||
@@ -45,7 +45,7 @@ impl SlashCommand for WorkflowSlashCommand {
|
||||
_query: String,
|
||||
_cancel: Arc<AtomicBool>,
|
||||
_workspace: Option<WeakView<Workspace>>,
|
||||
_cx: &mut WindowContext,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
Task::ready(Ok(Vec::new()))
|
||||
}
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use gpui::AppContext;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
||||
/// Settings for slash commands.
|
||||
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
|
||||
pub struct SlashCommandSettings {
|
||||
/// Settings for the `/docs` slash command.
|
||||
#[serde(default)]
|
||||
pub docs: DocsCommandSettings,
|
||||
/// Settings for the `/project` slash command.
|
||||
#[serde(default)]
|
||||
pub project: ProjectCommandSettings,
|
||||
}
|
||||
|
||||
/// Settings for the `/docs` slash command.
|
||||
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
|
||||
pub struct DocsCommandSettings {
|
||||
/// Whether `/docs` is enabled.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Settings for the `/project` slash command.
|
||||
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
|
||||
pub struct ProjectCommandSettings {
|
||||
/// Whether `/project` is enabled.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
impl Settings for SlashCommandSettings {
|
||||
const KEY: Option<&'static str> = Some("slash_commands");
|
||||
|
||||
type FileContent = Self;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _cx: &mut AppContext) -> Result<Self> {
|
||||
SettingsSources::<Self::FileContent>::json_merge_with(
|
||||
[sources.default].into_iter().chain(sources.user),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -276,7 +276,7 @@ impl TerminalInlineAssistant {
|
||||
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![prompt.into()],
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
|
||||
@@ -18,7 +18,7 @@ pub fn init(cx: &mut AppContext) {
|
||||
#[derive(Debug)]
|
||||
pub struct ArgumentCompletion {
|
||||
/// The label to display for this completion.
|
||||
pub label: CodeLabel,
|
||||
pub label: String,
|
||||
/// The new text that should be inserted into the command when this completion is accepted.
|
||||
pub new_text: String,
|
||||
/// Whether the command should be run when accepting this completion.
|
||||
@@ -37,7 +37,7 @@ pub trait SlashCommand: 'static + Send + Sync {
|
||||
query: String,
|
||||
cancel: Arc<AtomicBool>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>>;
|
||||
fn requires_argument(&self) -> bool;
|
||||
fn run(
|
||||
|
||||
@@ -56,14 +56,6 @@ impl SlashCommandRegistry {
|
||||
state.commands.insert(command_name, Arc::new(command));
|
||||
}
|
||||
|
||||
/// Unregisters the provided [`SlashCommand`].
|
||||
pub fn unregister_command(&self, command: impl SlashCommand) {
|
||||
let mut state = self.state.write();
|
||||
let command_name: Arc<str> = command.name().into();
|
||||
state.featured_commands.remove(&command_name);
|
||||
state.commands.remove(&command_name);
|
||||
}
|
||||
|
||||
/// Returns the names of registered [`SlashCommand`]s.
|
||||
pub fn command_names(&self) -> Vec<Arc<str>> {
|
||||
self.state.read().commands.keys().cloned().collect()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Duration;
|
||||
use futures::{stream::BoxStream, StreamExt};
|
||||
use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
|
||||
use parking_lot::Mutex;
|
||||
@@ -163,11 +162,6 @@ impl FakeServer {
|
||||
return Ok(*message.downcast().unwrap());
|
||||
}
|
||||
|
||||
let accepted_tos_at = chrono::Utc::now()
|
||||
.checked_sub_signed(Duration::hours(5))
|
||||
.expect("failed to build accepted_tos_at")
|
||||
.timestamp() as u64;
|
||||
|
||||
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
|
||||
self.respond(
|
||||
message
|
||||
@@ -178,7 +172,6 @@ impl FakeServer {
|
||||
metrics_id: "the-metrics-id".into(),
|
||||
staff: false,
|
||||
flags: Default::default(),
|
||||
accepted_tos_at: Some(accepted_tos_at),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::{proto, Client, Status, TypedEnvelope};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{hash_map::Entry, HashMap, HashSet};
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use futures::{channel::mpsc, Future, StreamExt};
|
||||
@@ -95,7 +94,6 @@ pub struct UserStore {
|
||||
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
|
||||
current_plan: Option<proto::Plan>,
|
||||
current_user: watch::Receiver<Option<Arc<User>>>,
|
||||
accepted_tos_at: Option<Option<DateTime<Utc>>>,
|
||||
contacts: Vec<Arc<Contact>>,
|
||||
incoming_contact_requests: Vec<Arc<User>>,
|
||||
outgoing_contact_requests: Vec<Arc<User>>,
|
||||
@@ -152,7 +150,6 @@ impl UserStore {
|
||||
by_github_login: Default::default(),
|
||||
current_user: current_user_rx,
|
||||
current_plan: None,
|
||||
accepted_tos_at: None,
|
||||
contacts: Default::default(),
|
||||
incoming_contact_requests: Default::default(),
|
||||
participant_indices: Default::default(),
|
||||
@@ -192,10 +189,9 @@ impl UserStore {
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
let fetch_private_user_info =
|
||||
let fetch_metrics_id =
|
||||
client.request(proto::GetPrivateUserInfo {}).log_err();
|
||||
let (user, info) =
|
||||
futures::join!(fetch_user, fetch_private_user_info);
|
||||
let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
|
||||
|
||||
cx.update(|cx| {
|
||||
if let Some(info) = info {
|
||||
@@ -206,17 +202,9 @@ impl UserStore {
|
||||
client.telemetry.set_authenticated_user_info(
|
||||
Some(info.metrics_id.clone()),
|
||||
staff,
|
||||
);
|
||||
|
||||
this.update(cx, |this, _| {
|
||||
this.set_current_user_accepted_tos_at(
|
||||
info.accepted_tos_at,
|
||||
);
|
||||
})
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
)
|
||||
}
|
||||
})??;
|
||||
})?;
|
||||
|
||||
current_user_tx.send(user).await.ok();
|
||||
|
||||
@@ -692,39 +680,6 @@ impl UserStore {
|
||||
self.current_user.clone()
|
||||
}
|
||||
|
||||
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
|
||||
self.accepted_tos_at
|
||||
.map(|accepted_tos_at| accepted_tos_at.is_some())
|
||||
}
|
||||
|
||||
pub fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
if self.current_user().is_none() {
|
||||
return Task::ready(Err(anyhow!("no current user")));
|
||||
};
|
||||
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |this, mut cx| async move {
|
||||
if let Some(client) = client.upgrade() {
|
||||
let response = client
|
||||
.request(proto::AcceptTermsOfService {})
|
||||
.await
|
||||
.context("error accepting tos")?;
|
||||
|
||||
this.update(&mut cx, |this, _| {
|
||||
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at))
|
||||
})
|
||||
} else {
|
||||
Err(anyhow!("client not found"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
|
||||
self.accepted_tos_at = Some(
|
||||
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
|
||||
);
|
||||
}
|
||||
|
||||
fn load_users(
|
||||
&mut self,
|
||||
request: impl RequestMessage<Response = UsersResponse>,
|
||||
|
||||
@@ -58,7 +58,6 @@ serde_derive.workspace = true
|
||||
serde_json.workspace = true
|
||||
sha2.workspace = true
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
|
||||
strum.workspace = true
|
||||
subtle.workspace = true
|
||||
rustc-demangle.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
|
||||
@@ -10,7 +10,7 @@ It contains our back-end logic for collaboration, to which we connect from the Z
|
||||
|
||||
Before you can run the collab server locally, you'll need to set up a zed Postgres database.
|
||||
|
||||
```sh
|
||||
```
|
||||
script/bootstrap
|
||||
```
|
||||
|
||||
@@ -32,13 +32,13 @@ To use a different set of admin users, create `crates/collab/seed.json`.
|
||||
|
||||
In one terminal, run Zed's collaboration server and the livekit dev server:
|
||||
|
||||
```sh
|
||||
```
|
||||
foreman start
|
||||
```
|
||||
|
||||
In a second terminal, run two or more instances of Zed.
|
||||
|
||||
```sh
|
||||
```
|
||||
script/zed-local -2
|
||||
```
|
||||
|
||||
@@ -64,7 +64,7 @@ You can tell what is currently deployed with `./script/what-is-deployed`.
|
||||
|
||||
To create a new migration:
|
||||
|
||||
```sh
|
||||
```
|
||||
./script/create-migration <name>
|
||||
```
|
||||
|
||||
|
||||
@@ -97,13 +97,6 @@ spec:
|
||||
secretKeyRef:
|
||||
name: llm-token
|
||||
key: secret
|
||||
- name: LLM_DATABASE_URL
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: llm-database
|
||||
key: url
|
||||
- name: LLM_DATABASE_MAX_CONNECTIONS
|
||||
value: "${LLM_DATABASE_MAX_CONNECTIONS}"
|
||||
- name: ZED_CLIENT_CHECKSUM_SEED
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
@@ -134,11 +127,6 @@ spec:
|
||||
secretKeyRef:
|
||||
name: anthropic
|
||||
key: api_key
|
||||
- name: ANTHROPIC_STAFF_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: anthropic
|
||||
key: staff_api_key
|
||||
- name: GOOGLE_AI_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
|
||||
@@ -3,4 +3,3 @@ RUST_LOG=info
|
||||
INVITE_LINK_PREFIX=https://zed.dev/invites/
|
||||
AUTO_JOIN_CHANNEL_ID=283
|
||||
DATABASE_MAX_CONNECTIONS=85
|
||||
LLM_DATABASE_MAX_CONNECTIONS=25
|
||||
|
||||
@@ -2,5 +2,4 @@ ZED_ENVIRONMENT=staging
|
||||
RUST_LOG=info
|
||||
INVITE_LINK_PREFIX=https://staging.zed.dev/invites/
|
||||
DATABASE_MAX_CONNECTIONS=5
|
||||
LLM_DATABASE_MAX_CONNECTIONS=5
|
||||
AUTO_JOIN_CHANNEL_ID=8
|
||||
|
||||
@@ -12,7 +12,7 @@ metadata:
|
||||
spec:
|
||||
type: LoadBalancer
|
||||
selector:
|
||||
app: nginx
|
||||
app: postgrest
|
||||
ports:
|
||||
- name: web
|
||||
protocol: TCP
|
||||
@@ -24,99 +24,17 @@ apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
namespace: ${ZED_KUBE_NAMESPACE}
|
||||
name: nginx
|
||||
name: postgrest
|
||||
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: nginx
|
||||
app: postgrest
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: nginx
|
||||
spec:
|
||||
containers:
|
||||
- name: nginx
|
||||
image: nginx:latest
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
protocol: TCP
|
||||
volumeMounts:
|
||||
- name: nginx-config
|
||||
mountPath: /etc/nginx/nginx.conf
|
||||
subPath: nginx.conf
|
||||
volumes:
|
||||
- name: nginx-config
|
||||
configMap:
|
||||
name: nginx-config
|
||||
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
namespace: ${ZED_KUBE_NAMESPACE}
|
||||
name: nginx-config
|
||||
data:
|
||||
nginx.conf: |
|
||||
events {}
|
||||
|
||||
http {
|
||||
server {
|
||||
listen 8080;
|
||||
|
||||
location /app/ {
|
||||
proxy_pass http://postgrest-app:8080/;
|
||||
}
|
||||
|
||||
location /llm/ {
|
||||
proxy_pass http://postgrest-llm:8080/;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
namespace: ${ZED_KUBE_NAMESPACE}
|
||||
name: postgrest-app
|
||||
spec:
|
||||
selector:
|
||||
app: postgrest-app
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 8080
|
||||
targetPort: 8080
|
||||
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
namespace: ${ZED_KUBE_NAMESPACE}
|
||||
name: postgrest-llm
|
||||
spec:
|
||||
selector:
|
||||
app: postgrest-llm
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 8080
|
||||
targetPort: 8080
|
||||
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
namespace: ${ZED_KUBE_NAMESPACE}
|
||||
name: postgrest-app
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: postgrest-app
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: postgrest-app
|
||||
app: postgrest
|
||||
spec:
|
||||
containers:
|
||||
- name: postgrest
|
||||
@@ -137,39 +55,3 @@ spec:
|
||||
secretKeyRef:
|
||||
name: postgrest
|
||||
key: jwt_secret
|
||||
|
||||
---
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
namespace: ${ZED_KUBE_NAMESPACE}
|
||||
name: postgrest-llm
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: postgrest-llm
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: postgrest-llm
|
||||
spec:
|
||||
containers:
|
||||
- name: postgrest
|
||||
image: "postgrest/postgrest"
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
protocol: TCP
|
||||
env:
|
||||
- name: PGRST_SERVER_PORT
|
||||
value: "8080"
|
||||
- name: PGRST_DB_URI
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: llm-database
|
||||
key: url
|
||||
- name: PGRST_JWT_SECRET
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: postgrest
|
||||
key: jwt_secret
|
||||
|
||||
@@ -9,9 +9,7 @@ CREATE TABLE "users" (
|
||||
"connected_once" BOOLEAN NOT NULL DEFAULT false,
|
||||
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"metrics_id" TEXT,
|
||||
"github_user_id" INTEGER,
|
||||
"accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE,
|
||||
"github_user_created_at" TIMESTAMP WITHOUT TIME ZONE
|
||||
"github_user_id" INTEGER
|
||||
);
|
||||
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
|
||||
CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE;
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE "users" ADD COLUMN "github_user_created_at" TIMESTAMP WITHOUT TIME ZONE;
|
||||
@@ -0,0 +1,32 @@
|
||||
create table providers (
|
||||
id integer primary key autoincrement,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_providers_on_name on providers (name);
|
||||
|
||||
create table models (
|
||||
id integer primary key autoincrement,
|
||||
provider_id integer not null references providers (id) on delete cascade,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
||||
create index ix_models_on_provider_id on models (provider_id);
|
||||
create index ix_models_on_name on models (name);
|
||||
|
||||
create table if not exists usages (
|
||||
id integer primary key autoincrement,
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
requests_this_minute integer not null default 0,
|
||||
tokens_this_minute integer not null default 0,
|
||||
requests_this_day integer not null default 0,
|
||||
tokens_this_day integer not null default 0,
|
||||
requests_this_month integer not null default 0,
|
||||
tokens_this_month integer not null default 0
|
||||
);
|
||||
|
||||
create index ix_usages_on_user_id on usages (user_id);
|
||||
create index ix_usages_on_model_id on usages (model_id);
|
||||
create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
|
||||
@@ -8,10 +8,7 @@ create unique index uix_providers_on_name on providers (name);
|
||||
create table if not exists models (
|
||||
id serial primary key,
|
||||
provider_id integer not null references providers (id) on delete cascade,
|
||||
name text not null,
|
||||
max_requests_per_minute integer not null,
|
||||
max_tokens_per_minute integer not null,
|
||||
max_tokens_per_day integer not null
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_models_on_provider_id_name on models (provider_id, name);
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
create table usage_measures (
|
||||
id serial primary key,
|
||||
name text not null
|
||||
);
|
||||
|
||||
create unique index uix_usage_measures_on_name on usage_measures (name);
|
||||
|
||||
create table if not exists usages (
|
||||
id serial primary key,
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
measure_id integer not null references usage_measures (id) on delete cascade,
|
||||
timestamp timestamp without time zone not null,
|
||||
buckets bigint[] not null
|
||||
requests_this_minute integer not null default 0,
|
||||
tokens_this_minute bigint not null default 0,
|
||||
requests_this_day integer not null default 0,
|
||||
tokens_this_day bigint not null default 0,
|
||||
requests_this_month integer not null default 0,
|
||||
tokens_this_month bigint not null default 0
|
||||
);
|
||||
|
||||
create index ix_usages_on_user_id on usages (user_id);
|
||||
create index ix_usages_on_model_id on usages (model_id);
|
||||
create unique index uix_usages_on_user_id_model_id_measure_id on usages (user_id, model_id, measure_id);
|
||||
create unique index uix_usages_on_user_id_model_id on usages (user_id, model_id);
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
ALTER TABLE models
|
||||
ALTER COLUMN max_requests_per_minute TYPE bigint,
|
||||
ALTER COLUMN max_tokens_per_minute TYPE bigint,
|
||||
ALTER COLUMN max_tokens_per_day TYPE bigint;
|
||||
@@ -1,3 +0,0 @@
|
||||
ALTER TABLE models
|
||||
ADD COLUMN price_per_million_input_tokens integer NOT NULL DEFAULT 0,
|
||||
ADD COLUMN price_per_million_output_tokens integer NOT NULL DEFAULT 0;
|
||||
@@ -1 +0,0 @@
|
||||
alter table usages add column is_staff boolean not null default false;
|
||||
@@ -1,9 +0,0 @@
|
||||
create table lifetime_usages (
|
||||
id serial primary key,
|
||||
user_id integer not null,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
input_tokens bigint not null default 0,
|
||||
output_tokens bigint not null default 0
|
||||
);
|
||||
|
||||
create unique index uix_lifetime_usages_on_user_id_model_id on lifetime_usages (user_id, model_id);
|
||||
@@ -1,7 +0,0 @@
|
||||
create table revoked_access_tokens (
|
||||
id serial primary key,
|
||||
jti text not null,
|
||||
revoked_at timestamp without time zone not null default now()
|
||||
);
|
||||
|
||||
create unique index uix_revoked_access_tokens_on_jti on revoked_access_tokens (jti);
|
||||
@@ -1,4 +0,0 @@
|
||||
db-uri = "postgres://postgres@localhost/zed_llm"
|
||||
server-port = 8082
|
||||
jwt-secret = "the-postgrest-jwt-secret-for-authorization"
|
||||
log-level = "info"
|
||||
@@ -111,7 +111,6 @@ struct AuthenticatedUserParams {
|
||||
github_user_id: Option<i32>,
|
||||
github_login: String,
|
||||
github_email: Option<String>,
|
||||
github_user_created_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -132,7 +131,6 @@ async fn get_authenticated_user(
|
||||
¶ms.github_login,
|
||||
params.github_user_id,
|
||||
params.github_email.as_deref(),
|
||||
params.github_user_created_at,
|
||||
initial_channel_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -115,7 +115,6 @@ async fn add_contributor(
|
||||
¶ms.github_login,
|
||||
params.github_user_id,
|
||||
params.github_email.as_deref(),
|
||||
params.github_user_created_at,
|
||||
initial_channel_id,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -65,7 +65,6 @@ impl Database {
|
||||
github_login: &str,
|
||||
github_user_id: Option<i32>,
|
||||
github_email: Option<&str>,
|
||||
github_user_created_at: Option<DateTimeUtc>,
|
||||
initial_channel_id: Option<ChannelId>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
@@ -74,7 +73,6 @@ impl Database {
|
||||
github_login,
|
||||
github_user_id,
|
||||
github_email,
|
||||
github_user_created_at.map(|time| time.naive_utc()),
|
||||
initial_channel_id,
|
||||
&tx,
|
||||
)
|
||||
|
||||
@@ -5,27 +5,15 @@ use util::ResultExt;
|
||||
impl Database {
|
||||
/// Initializes the different kinds of notifications by upserting records for them.
|
||||
pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
|
||||
let all_kinds = Notification::all_variant_names();
|
||||
let existing_kinds = notification_kind::Entity::find().all(&self.pool).await?;
|
||||
|
||||
let kinds_to_create: Vec<_> = all_kinds
|
||||
.iter()
|
||||
.filter(|&kind| {
|
||||
!existing_kinds
|
||||
.iter()
|
||||
.any(|existing| existing.name == **kind)
|
||||
})
|
||||
.map(|kind| notification_kind::ActiveModel {
|
||||
notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
|
||||
|kind| notification_kind::ActiveModel {
|
||||
name: ActiveValue::Set(kind.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !kinds_to_create.is_empty() {
|
||||
notification_kind::Entity::insert_many(kinds_to_create)
|
||||
.exec_without_returning(&self.pool)
|
||||
.await?;
|
||||
}
|
||||
},
|
||||
))
|
||||
.on_conflict(OnConflict::new().do_nothing().to_owned())
|
||||
.exec_without_returning(&self.pool)
|
||||
.await?;
|
||||
|
||||
let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
|
||||
while let Some(row) = rows.next().await {
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use chrono::NaiveDateTime;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl Database {
|
||||
@@ -101,7 +99,6 @@ impl Database {
|
||||
github_login: &str,
|
||||
github_user_id: Option<i32>,
|
||||
github_email: Option<&str>,
|
||||
github_user_created_at: Option<DateTimeUtc>,
|
||||
initial_channel_id: Option<ChannelId>,
|
||||
) -> Result<User> {
|
||||
self.transaction(|tx| async move {
|
||||
@@ -109,7 +106,6 @@ impl Database {
|
||||
github_login,
|
||||
github_user_id,
|
||||
github_email,
|
||||
github_user_created_at.map(|created_at| created_at.naive_utc()),
|
||||
initial_channel_id,
|
||||
&tx,
|
||||
)
|
||||
@@ -123,7 +119,6 @@ impl Database {
|
||||
github_login: &str,
|
||||
github_user_id: Option<i32>,
|
||||
github_email: Option<&str>,
|
||||
github_user_created_at: Option<NaiveDateTime>,
|
||||
initial_channel_id: Option<ChannelId>,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<User> {
|
||||
@@ -135,10 +130,6 @@ impl Database {
|
||||
{
|
||||
let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
|
||||
user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
|
||||
if github_user_created_at.is_some() {
|
||||
user_by_github_user_id.github_user_created_at =
|
||||
ActiveValue::set(github_user_created_at);
|
||||
}
|
||||
Ok(user_by_github_user_id.update(tx).await?)
|
||||
} else if let Some(user_by_github_login) = user::Entity::find()
|
||||
.filter(user::Column::GithubLogin.eq(github_login))
|
||||
@@ -147,17 +138,12 @@ impl Database {
|
||||
{
|
||||
let mut user_by_github_login = user_by_github_login.into_active_model();
|
||||
user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
|
||||
if github_user_created_at.is_some() {
|
||||
user_by_github_login.github_user_created_at =
|
||||
ActiveValue::set(github_user_created_at);
|
||||
}
|
||||
Ok(user_by_github_login.update(tx).await?)
|
||||
} else {
|
||||
let user = user::Entity::insert(user::ActiveModel {
|
||||
email_address: ActiveValue::set(github_email.map(|email| email.into())),
|
||||
github_login: ActiveValue::set(github_login.into()),
|
||||
github_user_id: ActiveValue::set(Some(github_user_id)),
|
||||
github_user_created_at: ActiveValue::set(github_user_created_at),
|
||||
admin: ActiveValue::set(false),
|
||||
invite_count: ActiveValue::set(0),
|
||||
invite_code: ActiveValue::set(None),
|
||||
@@ -239,26 +225,6 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
/// Sets "accepted_tos_at" on the user to the given timestamp.
|
||||
pub async fn set_user_accepted_tos_at(
|
||||
&self,
|
||||
id: UserId,
|
||||
accepted_tos_at: Option<DateTime>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
user::Entity::update_many()
|
||||
.filter(user::Column::Id.eq(id))
|
||||
.set(user::ActiveModel {
|
||||
accepted_tos_at: ActiveValue::set(accepted_tos_at),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// hard delete the user.
|
||||
pub async fn destroy_user(&self, id: UserId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::db::UserId;
|
||||
use chrono::NaiveDateTime;
|
||||
use sea_orm::entity::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -11,7 +10,6 @@ pub struct Model {
|
||||
pub id: UserId,
|
||||
pub github_login: String,
|
||||
pub github_user_id: Option<i32>,
|
||||
pub github_user_created_at: Option<NaiveDateTime>,
|
||||
pub email_address: Option<String>,
|
||||
pub admin: bool,
|
||||
pub invite_code: Option<String>,
|
||||
@@ -19,8 +17,7 @@ pub struct Model {
|
||||
pub inviter_id: Option<UserId>,
|
||||
pub connected_once: bool,
|
||||
pub metrics_id: Uuid,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub accepted_tos_at: Option<NaiveDateTime>,
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
||||
@@ -10,7 +10,6 @@ mod extension_tests;
|
||||
mod feature_flag_tests;
|
||||
mod message_tests;
|
||||
mod processed_stripe_event_tests;
|
||||
mod user_tests;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use chrono::Utc;
|
||||
|
||||
use super::Database;
|
||||
use crate::{db::NewUserParams, test_both_dbs};
|
||||
use std::sync::Arc;
|
||||
@@ -24,8 +22,7 @@ async fn test_contributors(db: &Arc<Database>) {
|
||||
|
||||
assert_eq!(db.get_contributors().await.unwrap(), Vec::<String>::new());
|
||||
|
||||
let user1_created_at = Utc::now();
|
||||
db.add_contributor("user1", Some(1), None, Some(user1_created_at), None)
|
||||
db.add_contributor("user1", Some(1), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
@@ -33,8 +30,7 @@ async fn test_contributors(db: &Arc<Database>) {
|
||||
vec!["user1".to_string()]
|
||||
);
|
||||
|
||||
let user2_created_at = Utc::now();
|
||||
db.add_contributor("user2", Some(2), None, Some(user2_created_at), None)
|
||||
db.add_contributor("user2", Some(2), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
use crate::test_both_dbs;
|
||||
use chrono::Utc;
|
||||
use pretty_assertions::{assert_eq, assert_ne};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -101,13 +100,7 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
|
||||
.user_id;
|
||||
|
||||
let user = db
|
||||
.get_or_create_user_by_github_account(
|
||||
"the-new-login2",
|
||||
Some(102),
|
||||
None,
|
||||
Some(Utc::now()),
|
||||
None,
|
||||
)
|
||||
.get_or_create_user_by_github_account("the-new-login2", Some(102), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(user.id, user_id2);
|
||||
@@ -115,13 +108,7 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
|
||||
assert_eq!(user.github_user_id, Some(102));
|
||||
|
||||
let user = db
|
||||
.get_or_create_user_by_github_account(
|
||||
"login3",
|
||||
Some(103),
|
||||
Some("user3@example.com"),
|
||||
Some(Utc::now()),
|
||||
None,
|
||||
)
|
||||
.get_or_create_user_by_github_account("login3", Some(103), Some("user3@example.com"), None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(&user.github_login, "login3");
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
use chrono::Utc;
|
||||
|
||||
use crate::{
|
||||
db::{Database, NewUserParams},
|
||||
test_both_dbs,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
test_both_dbs!(
|
||||
test_accepted_tos,
|
||||
test_accepted_tos_postgres,
|
||||
test_accepted_tos_sqlite
|
||||
);
|
||||
|
||||
async fn test_accepted_tos(db: &Arc<Database>) {
|
||||
let user_id = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".to_string(),
|
||||
github_user_id: 1,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
|
||||
assert!(user.accepted_tos_at.is_none());
|
||||
|
||||
let accepted_tos_at = Utc::now().naive_utc();
|
||||
db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
|
||||
assert!(user.accepted_tos_at.is_some());
|
||||
assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));
|
||||
|
||||
db.set_user_accepted_tos_at(user_id, None).await.unwrap();
|
||||
|
||||
let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
|
||||
assert!(user.accepted_tos_at.is_none());
|
||||
}
|
||||
@@ -166,7 +166,6 @@ pub struct Config {
|
||||
pub openai_api_key: Option<Arc<str>>,
|
||||
pub google_ai_api_key: Option<Arc<str>>,
|
||||
pub anthropic_api_key: Option<Arc<str>>,
|
||||
pub anthropic_staff_api_key: Option<Arc<str>>,
|
||||
pub qwen2_7b_api_key: Option<Arc<str>>,
|
||||
pub qwen2_7b_api_url: Option<Arc<str>>,
|
||||
pub zed_client_checksum_seed: Option<String>,
|
||||
@@ -217,7 +216,6 @@ impl Config {
|
||||
openai_api_key: None,
|
||||
google_ai_api_key: None,
|
||||
anthropic_api_key: None,
|
||||
anthropic_staff_api_key: None,
|
||||
clickhouse_url: None,
|
||||
clickhouse_user: None,
|
||||
clickhouse_password: None,
|
||||
@@ -237,8 +235,7 @@ impl Config {
|
||||
}
|
||||
|
||||
/// The service mode that collab should run in.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum ServiceMode {
|
||||
Api,
|
||||
Collab,
|
||||
|
||||
@@ -1,67 +1,55 @@
|
||||
mod authorization;
|
||||
pub mod db;
|
||||
mod telemetry;
|
||||
mod token;
|
||||
|
||||
use crate::{
|
||||
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor,
|
||||
Config, Error, Result,
|
||||
};
|
||||
use crate::api::CloudflareIpCountryHeader;
|
||||
use crate::llm::authorization::authorize_access_to_language_model;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::{executor::Executor, Config, Error, Result};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use authorization::authorize_access_to_language_model;
|
||||
use axum::TypedHeader;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{self, HeaderName, HeaderValue, Request, StatusCode},
|
||||
middleware::{self, Next},
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
Extension, Json, Router, TypedHeader,
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||
use futures::{Stream, StreamExt as _};
|
||||
use futures::StreamExt as _;
|
||||
use http_client::IsahcHttpClient;
|
||||
use rpc::{
|
||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
};
|
||||
use std::{
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use telemetry::{report_llm_rate_limit, report_llm_usage, LlmRateLimitEventRow, LlmUsageEventRow};
|
||||
use tokio::sync::RwLock;
|
||||
use util::ResultExt;
|
||||
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use token::*;
|
||||
|
||||
pub struct LlmState {
|
||||
pub config: Config,
|
||||
pub executor: Executor,
|
||||
pub db: Arc<LlmDatabase>,
|
||||
pub db: Option<Arc<LlmDatabase>>,
|
||||
pub http_client: IsahcHttpClient,
|
||||
pub clickhouse_client: Option<clickhouse::Client>,
|
||||
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
|
||||
}
|
||||
|
||||
const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
|
||||
|
||||
impl LlmState {
|
||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
let max_connections = config
|
||||
.llm_database_max_connections
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||
// TODO: This is temporary until we have the LLM database stood up.
|
||||
let db = if config.is_development() {
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||
let max_connections = config
|
||||
.llm_database_max_connections
|
||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||
db.initialize().await?;
|
||||
let mut db_options = db::ConnectOptions::new(database_url);
|
||||
db_options.max_connections(max_connections);
|
||||
let db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||
|
||||
let db = Arc::new(db);
|
||||
Some(Arc::new(db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||
let http_client = IsahcHttpClient::builder()
|
||||
@@ -69,38 +57,15 @@ impl LlmState {
|
||||
.build()
|
||||
.context("failed to construct http client")?;
|
||||
|
||||
let initial_active_user_count =
|
||||
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
|
||||
|
||||
let this = Self {
|
||||
config,
|
||||
executor,
|
||||
db,
|
||||
http_client,
|
||||
clickhouse_client: config
|
||||
.clickhouse_url
|
||||
.as_ref()
|
||||
.and_then(|_| build_clickhouse_client(&config).log_err()),
|
||||
active_user_count: RwLock::new(initial_active_user_count),
|
||||
config,
|
||||
};
|
||||
|
||||
Ok(Arc::new(this))
|
||||
}
|
||||
|
||||
pub async fn get_active_user_count(&self) -> Result<ActiveUserCount> {
|
||||
let now = Utc::now();
|
||||
|
||||
if let Some((last_updated, count)) = self.active_user_count.read().await.as_ref() {
|
||||
if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION {
|
||||
return Ok(*count);
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = self.active_user_count.write().await;
|
||||
let new_count = self.db.get_active_user_count(now).await?;
|
||||
*cache = Some((now, new_count));
|
||||
Ok(new_count)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes() -> Router<(), Body> {
|
||||
@@ -131,15 +96,6 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
|
||||
let state = req.extensions().get::<Arc<LlmState>>().unwrap();
|
||||
match LlmTokenClaims::validate(&token, &state.config) {
|
||||
Ok(claims) => {
|
||||
if state.db.is_access_token_revoked(&claims.jti).await? {
|
||||
return Err(Error::http(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"unauthorized".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
tracing::Span::current().record("authn.jti", &claims.jti);
|
||||
|
||||
req.extensions_mut().insert(claims);
|
||||
Ok::<_, Error>(next.run(req).await.into_response())
|
||||
}
|
||||
@@ -166,96 +122,41 @@ async fn perform_completion(
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
Json(params): Json<PerformCompletionParams>,
|
||||
) -> Result<impl IntoResponse> {
|
||||
let model = normalize_model_name(params.provider, params.model);
|
||||
|
||||
authorize_access_to_language_model(
|
||||
&state.config,
|
||||
&claims,
|
||||
country_code_header.map(|header| header.to_string()),
|
||||
params.provider,
|
||||
&model,
|
||||
¶ms.model,
|
||||
)?;
|
||||
|
||||
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
||||
|
||||
let stream = match params.provider {
|
||||
match params.provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
let api_key = if claims.is_staff {
|
||||
state
|
||||
.config
|
||||
.anthropic_staff_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI staff API key configured on the server")?
|
||||
} else {
|
||||
state
|
||||
.config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?
|
||||
};
|
||||
|
||||
let mut request: anthropic::Request =
|
||||
serde_json::from_str(¶ms.provider_request.get())?;
|
||||
|
||||
// Parse the model, throw away the version that was included, and then set a specific
|
||||
// version that we control on the server.
|
||||
// Right now, we use the version that's defined in `model.id()`, but we will likely
|
||||
// want to change this code once a new version of an Anthropic model is released,
|
||||
// so that users can use the new version, without having to update Zed.
|
||||
request.model = match anthropic::Model::from_id(&request.model) {
|
||||
Ok(model) => model.id().to_string(),
|
||||
Err(_) => request.model,
|
||||
};
|
||||
|
||||
let api_key = state
|
||||
.config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?;
|
||||
let chunks = anthropic::stream_completion(
|
||||
&state.http_client,
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
request,
|
||||
serde_json::from_str(¶ms.provider_request.get())?,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() {
|
||||
Some(anthropic::ApiErrorCode::RateLimitError) => Error::http(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Upstream Anthropic rate limit exceeded.".to_string(),
|
||||
),
|
||||
Some(anthropic::ApiErrorCode::InvalidRequestError) => {
|
||||
Error::http(StatusCode::BAD_REQUEST, api_error.message.clone())
|
||||
}
|
||||
Some(anthropic::ApiErrorCode::OverloadedError) => {
|
||||
Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone())
|
||||
}
|
||||
Some(_) => {
|
||||
Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone())
|
||||
}
|
||||
None => Error::Internal(anyhow!(err)),
|
||||
},
|
||||
anthropic::AnthropicError::Other(err) => Error::Internal(err),
|
||||
})?;
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(move |event| {
|
||||
let chunk = event?;
|
||||
let (input_tokens, output_tokens) = match &chunk {
|
||||
anthropic::Event::MessageStart {
|
||||
message: anthropic::Response { usage, .. },
|
||||
}
|
||||
| anthropic::Event::MessageDelta { usage, .. } => (
|
||||
usage.input_tokens.unwrap_or(0) as usize,
|
||||
usage.output_tokens.unwrap_or(0) as usize,
|
||||
),
|
||||
_ => (0, 0),
|
||||
};
|
||||
|
||||
anyhow::Ok((
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
))
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
.boxed()
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
LanguageModelProvider::OpenAi => {
|
||||
let api_key = state
|
||||
@@ -272,21 +173,17 @@ async fn perform_completion(
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
let input_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||
let output_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||
(
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
})
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
.boxed()
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
LanguageModelProvider::Google => {
|
||||
let api_key = state
|
||||
@@ -302,20 +199,17 @@ async fn perform_completion(
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
// TODO - implement token counting for Google AI
|
||||
let input_tokens = 0;
|
||||
let output_tokens = 0;
|
||||
(
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
})
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
.boxed()
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
LanguageModelProvider::Zed => {
|
||||
let api_key = state
|
||||
@@ -337,244 +231,17 @@ async fn perform_completion(
|
||||
)
|
||||
.await?;
|
||||
|
||||
chunks
|
||||
.map(|event| {
|
||||
event.map(|chunk| {
|
||||
let input_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
|
||||
let output_tokens =
|
||||
chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
|
||||
(
|
||||
serde_json::to_vec(&chunk).unwrap(),
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
})
|
||||
let stream = chunks.map(|event| {
|
||||
let mut buffer = Vec::new();
|
||||
event.map(|chunk| {
|
||||
buffer.clear();
|
||||
serde_json::to_writer(&mut buffer, &chunk).unwrap();
|
||||
buffer.push(b'\n');
|
||||
buffer
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
Ok(Response::new(Body::wrap_stream(TokenCountingStream {
|
||||
state,
|
||||
claims,
|
||||
provider: params.provider,
|
||||
model,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
inner_stream: stream,
|
||||
})))
|
||||
}
|
||||
|
||||
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
||||
let prefixes: &[_] = match provider {
|
||||
LanguageModelProvider::Anthropic => &[
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-haiku",
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
],
|
||||
LanguageModelProvider::OpenAi => &[
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
],
|
||||
LanguageModelProvider::Google => &[],
|
||||
LanguageModelProvider::Zed => &[],
|
||||
};
|
||||
|
||||
if let Some(prefix) = prefixes
|
||||
.iter()
|
||||
.filter(|&&prefix| name.starts_with(prefix))
|
||||
.max_by_key(|&&prefix| prefix.len())
|
||||
{
|
||||
prefix.to_string()
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_usage_limit(
|
||||
state: &Arc<LlmState>,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
claims: &LlmTokenClaims,
|
||||
) -> Result<()> {
|
||||
let model = state.db.model(provider, model_name)?;
|
||||
let usage = state
|
||||
.db
|
||||
.get_usage(
|
||||
UserId::from_proto(claims.user_id),
|
||||
provider,
|
||||
model_name,
|
||||
Utc::now(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let active_users = state.get_active_user_count().await?;
|
||||
|
||||
let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1);
|
||||
let users_in_recent_days = active_users.users_in_recent_days.max(1);
|
||||
|
||||
let per_user_max_requests_per_minute =
|
||||
model.max_requests_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_tokens_per_minute =
|
||||
model.max_tokens_per_minute as usize / users_in_recent_minutes;
|
||||
let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days;
|
||||
|
||||
let checks = [
|
||||
(
|
||||
usage.requests_this_minute,
|
||||
per_user_max_requests_per_minute,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_minute,
|
||||
per_user_max_tokens_per_minute,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
),
|
||||
(
|
||||
usage.tokens_this_day,
|
||||
per_user_max_tokens_per_day,
|
||||
UsageMeasure::TokensPerDay,
|
||||
),
|
||||
];
|
||||
|
||||
for (used, limit, usage_measure) in checks {
|
||||
// Temporarily bypass rate-limiting for staff members.
|
||||
if claims.is_staff {
|
||||
continue;
|
||||
}
|
||||
|
||||
if used > limit {
|
||||
let resource = match usage_measure {
|
||||
UsageMeasure::RequestsPerMinute => "requests_per_minute",
|
||||
UsageMeasure::TokensPerMinute => "tokens_per_minute",
|
||||
UsageMeasure::TokensPerDay => "tokens_per_day",
|
||||
_ => "",
|
||||
};
|
||||
|
||||
if let Some(client) = state.clickhouse_client.as_ref() {
|
||||
report_llm_rate_limit(
|
||||
client,
|
||||
LlmRateLimitEventRow {
|
||||
time: Utc::now().timestamp_millis(),
|
||||
user_id: claims.user_id as i32,
|
||||
is_staff: claims.is_staff,
|
||||
plan: match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
model: model.name.clone(),
|
||||
provider: provider.to_string(),
|
||||
usage_measure: resource.to_string(),
|
||||
requests_this_minute: usage.requests_this_minute as u64,
|
||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||
tokens_this_day: usage.tokens_this_day as u64,
|
||||
users_in_recent_minutes: users_in_recent_minutes as u64,
|
||||
users_in_recent_days: users_in_recent_days as u64,
|
||||
max_requests_per_minute: per_user_max_requests_per_minute as u64,
|
||||
max_tokens_per_minute: per_user_max_tokens_per_minute as u64,
|
||||
max_tokens_per_day: per_user_max_tokens_per_day as u64,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
return Err(Error::http(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
format!("Rate limit exceeded. Maximum {} reached.", resource),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct TokenCountingStream<S> {
|
||||
state: Arc<LlmState>,
|
||||
claims: LlmTokenClaims,
|
||||
provider: LanguageModelProvider,
|
||||
model: String,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
inner_stream: S,
|
||||
}
|
||||
|
||||
impl<S> Stream for TokenCountingStream<S>
|
||||
where
|
||||
S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
|
||||
{
|
||||
type Item = Result<Vec<u8>, anyhow::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
|
||||
bytes.push(b'\n');
|
||||
self.input_tokens += input_tokens;
|
||||
self.output_tokens += output_tokens;
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
}
|
||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
Ok(Response::new(Body::wrap_stream(stream)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Drop for TokenCountingStream<S> {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let claims = self.claims.clone();
|
||||
let provider = self.provider;
|
||||
let model = std::mem::take(&mut self.model);
|
||||
let input_token_count = self.input_tokens;
|
||||
let output_token_count = self.output_tokens;
|
||||
self.state.executor.spawn_detached(async move {
|
||||
let usage = state
|
||||
.db
|
||||
.record_usage(
|
||||
UserId::from_proto(claims.user_id),
|
||||
claims.is_staff,
|
||||
provider,
|
||||
&model,
|
||||
input_token_count,
|
||||
output_token_count,
|
||||
Utc::now(),
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
|
||||
report_llm_usage(
|
||||
clickhouse_client,
|
||||
LlmUsageEventRow {
|
||||
time: Utc::now().timestamp_millis(),
|
||||
user_id: claims.user_id as i32,
|
||||
is_staff: claims.is_staff,
|
||||
plan: match claims.plan {
|
||||
Plan::Free => "free".to_string(),
|
||||
Plan::ZedPro => "zed_pro".to_string(),
|
||||
},
|
||||
model,
|
||||
provider: provider.to_string(),
|
||||
input_token_count: input_token_count as u64,
|
||||
output_token_count: output_token_count as u64,
|
||||
requests_this_minute: usage.requests_this_minute as u64,
|
||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||
tokens_this_day: usage.tokens_this_day as u64,
|
||||
input_tokens_this_month: usage.input_tokens_this_month as u64,
|
||||
output_tokens_this_month: usage.output_tokens_this_month as u64,
|
||||
spending_this_month: usage.spending_this_month as u64,
|
||||
lifetime_spending: usage.lifetime_spending as u64,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,20 @@
|
||||
mod ids;
|
||||
mod queries;
|
||||
mod seed;
|
||||
mod tables;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use collections::HashMap;
|
||||
pub use ids::*;
|
||||
use rpc::LanguageModelProvider;
|
||||
pub use seed::*;
|
||||
pub use tables::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use tests::TestLlmDb;
|
||||
use usage_measure::UsageMeasure;
|
||||
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
pub use queries::usages::ActiveUserCount;
|
||||
use sea_orm::prelude::*;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::{
|
||||
@@ -37,9 +31,6 @@ pub struct LlmDatabase {
|
||||
pool: DatabaseConnection,
|
||||
#[allow(unused)]
|
||||
executor: Executor,
|
||||
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
|
||||
models: HashMap<(LanguageModelProvider, String), model::Model>,
|
||||
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
|
||||
#[cfg(test)]
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
}
|
||||
@@ -52,28 +43,11 @@ impl LlmDatabase {
|
||||
options: options.clone(),
|
||||
pool: sea_orm::Database::connect(options).await?,
|
||||
executor,
|
||||
provider_ids: HashMap::default(),
|
||||
models: HashMap::default(),
|
||||
usage_measure_ids: HashMap::default(),
|
||||
#[cfg(test)]
|
||||
runtime: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn initialize(&mut self) -> Result<()> {
|
||||
self.initialize_providers().await?;
|
||||
self.initialize_models().await?;
|
||||
self.initialize_usage_measures().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
.get(&(provider, name.to_string()))
|
||||
.ok_or_else(|| anyhow!("unknown model {provider:?}:{name}"))?)
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.options
|
||||
}
|
||||
|
||||
@@ -6,5 +6,3 @@ use crate::id_type;
|
||||
id_type!(ModelId);
|
||||
id_type!(ProviderId);
|
||||
id_type!(UsageId);
|
||||
id_type!(UsageMeasureId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use super::*;
|
||||
|
||||
pub mod providers;
|
||||
pub mod revoked_access_tokens;
|
||||
pub mod usages;
|
||||
|
||||
@@ -1,122 +1,66 @@
|
||||
use super::*;
|
||||
use sea_orm::sea_query::OnConflict;
|
||||
use sea_orm::QueryOrder;
|
||||
use std::str::FromStr;
|
||||
use strum::IntoEnumIterator as _;
|
||||
|
||||
pub struct ModelParams {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub name: String,
|
||||
pub max_requests_per_minute: i64,
|
||||
pub max_tokens_per_minute: i64,
|
||||
pub max_tokens_per_day: i64,
|
||||
pub price_per_million_input_tokens: i32,
|
||||
pub price_per_million_output_tokens: i32,
|
||||
}
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_providers(&mut self) -> Result<()> {
|
||||
self.provider_ids = self
|
||||
.transaction(|tx| async move {
|
||||
let existing_providers = provider::Entity::find().all(&*tx).await?;
|
||||
|
||||
let mut new_providers = LanguageModelProvider::iter()
|
||||
.filter(|provider| {
|
||||
!existing_providers
|
||||
.iter()
|
||||
.any(|p| p.name == provider.to_string())
|
||||
})
|
||||
.map(|provider| provider::ActiveModel {
|
||||
name: ActiveValue::set(provider.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.peekable();
|
||||
|
||||
if new_providers.peek().is_some() {
|
||||
provider::Entity::insert_many(new_providers)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let all_providers: HashMap<_, _> = provider::Entity::find()
|
||||
.all(&*tx)
|
||||
.await?
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
LanguageModelProvider::from_str(&provider.name)
|
||||
.ok()
|
||||
.map(|p| (p, provider.id))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(all_providers)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn initialize_models(&mut self) -> Result<()> {
|
||||
let all_provider_ids = &self.provider_ids;
|
||||
self.models = self
|
||||
.transaction(|tx| async move {
|
||||
let all_models: HashMap<_, _> = model::Entity::find()
|
||||
.all(&*tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|model| {
|
||||
let provider = all_provider_ids.iter().find_map(|(provider, id)| {
|
||||
if *id == model.provider_id {
|
||||
Some(provider)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})?;
|
||||
Some(((*provider, model.name.clone()), model))
|
||||
})
|
||||
.collect();
|
||||
Ok(all_models)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
|
||||
let all_provider_ids = &self.provider_ids;
|
||||
pub async fn initialize_providers(&self) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
model::Entity::insert_many(models.into_iter().map(|model_params| {
|
||||
let provider_id = all_provider_ids[&model_params.provider];
|
||||
model::ActiveModel {
|
||||
provider_id: ActiveValue::set(provider_id),
|
||||
name: ActiveValue::set(model_params.name.clone()),
|
||||
max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
|
||||
max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
|
||||
max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
|
||||
price_per_million_input_tokens: ActiveValue::set(
|
||||
model_params.price_per_million_input_tokens,
|
||||
),
|
||||
price_per_million_output_tokens: ActiveValue::set(
|
||||
model_params.price_per_million_output_tokens,
|
||||
),
|
||||
let providers_and_models = vec![
|
||||
("anthropic", "claude-3-5-sonnet"),
|
||||
("anthropic", "claude-3-opus"),
|
||||
("anthropic", "claude-3-sonnet"),
|
||||
("anthropic", "claude-3-haiku"),
|
||||
];
|
||||
|
||||
for (provider_name, model_name) in providers_and_models {
|
||||
let insert_provider = provider::Entity::insert(provider::ActiveModel {
|
||||
name: ActiveValue::set(provider_name.to_owned()),
|
||||
..Default::default()
|
||||
}
|
||||
}))
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([provider::Column::Name])
|
||||
.update_column(provider::Column::Name)
|
||||
.to_owned(),
|
||||
);
|
||||
|
||||
let provider = if tx.support_returning() {
|
||||
insert_provider.exec_with_returning(&*tx).await?
|
||||
} else {
|
||||
insert_provider.exec_without_returning(&*tx).await?;
|
||||
provider::Entity::find()
|
||||
.filter(provider::Column::Name.eq(provider_name))
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("failed to insert provider"))?
|
||||
};
|
||||
|
||||
model::Entity::insert(model::ActiveModel {
|
||||
provider_id: ActiveValue::set(provider.id),
|
||||
name: ActiveValue::set(model_name.to_owned()),
|
||||
..Default::default()
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
|
||||
.update_column(model::Column::Name)
|
||||
.to_owned(),
|
||||
)
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
self.initialize_models().await
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the list of LLM providers.
|
||||
pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
|
||||
pub async fn list_providers(&self) -> Result<Vec<provider::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(provider::Entity::find()
|
||||
.order_by_asc(provider::Column::Name)
|
||||
.all(&*tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
|
||||
.collect())
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
/// Returns whether the access token with the given `jti` has been revoked.
|
||||
pub async fn is_access_token_revoked(&self, jti: &str) -> Result<bool> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(revoked_access_token::Entity::find()
|
||||
.filter(revoked_access_token::Column::Jti.eq(jti))
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.is_some())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,449 +1,57 @@
|
||||
use crate::db::UserId;
|
||||
use chrono::Duration;
|
||||
use rpc::LanguageModelProvider;
|
||||
use sea_orm::QuerySelect;
|
||||
use std::{iter, str::FromStr};
|
||||
use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||
pub struct Usage {
|
||||
pub requests_this_minute: usize,
|
||||
pub tokens_this_minute: usize,
|
||||
pub tokens_this_day: usize,
|
||||
pub input_tokens_this_month: usize,
|
||||
pub output_tokens_this_month: usize,
|
||||
pub spending_this_month: usize,
|
||||
pub lifetime_spending: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct ActiveUserCount {
|
||||
pub users_in_recent_minutes: usize,
|
||||
pub users_in_recent_days: usize,
|
||||
}
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
|
||||
let all_measures = self
|
||||
.transaction(|tx| async move {
|
||||
let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
|
||||
|
||||
let new_measures = UsageMeasure::iter()
|
||||
.filter(|measure| {
|
||||
!existing_measures
|
||||
.iter()
|
||||
.any(|m| m.name == measure.to_string())
|
||||
})
|
||||
.map(|measure| usage_measure::ActiveModel {
|
||||
name: ActiveValue::set(measure.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !new_measures.is_empty() {
|
||||
usage_measure::Entity::insert_many(new_measures)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(usage_measure::Entity::find().all(&*tx).await?)
|
||||
})
|
||||
.await?;
|
||||
|
||||
self.usage_measure_ids = all_measures
|
||||
.into_iter()
|
||||
.filter_map(|measure| {
|
||||
UsageMeasure::from_str(&measure.name)
|
||||
.ok()
|
||||
.map(|um| (um, measure.id))
|
||||
})
|
||||
.collect();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_usage(
|
||||
pub async fn find_or_create_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
user_id: i32,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
) -> Result<usage::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let model = self
|
||||
.models
|
||||
.get(&(provider, model_name.to_string()))
|
||||
.ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?;
|
||||
let provider_name = match provider {
|
||||
LanguageModelProvider::Anthropic => "anthropic",
|
||||
LanguageModelProvider::OpenAi => "open_ai",
|
||||
LanguageModelProvider::Google => "google",
|
||||
LanguageModelProvider::Zed => "zed",
|
||||
};
|
||||
|
||||
let usages = usage::Entity::find()
|
||||
let model = model::Entity::find()
|
||||
.inner_join(provider::Entity)
|
||||
.filter(
|
||||
usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let (lifetime_input_tokens, lifetime_output_tokens) = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
provider::Column::Name
|
||||
.eq(provider_name)
|
||||
.and(model::Column::Name.eq(model_name)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.map_or((0, 0), |usage| {
|
||||
(usage.input_tokens as usize, usage.output_tokens as usize)
|
||||
});
|
||||
// TODO: Create the model, if one doesn't exist.
|
||||
.ok_or_else(|| anyhow!("no model found for {provider_name}:{model_name}"))?;
|
||||
let model_id = model.id;
|
||||
|
||||
let requests_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?;
|
||||
let tokens_this_minute =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?;
|
||||
let tokens_this_day =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?;
|
||||
let input_tokens_this_month =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMonth)?;
|
||||
let output_tokens_this_month =
|
||||
self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?;
|
||||
let spending_this_month =
|
||||
calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
|
||||
let lifetime_spending =
|
||||
calculate_spending(model, lifetime_input_tokens, lifetime_output_tokens);
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
input_tokens_this_month,
|
||||
output_tokens_this_month,
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn record_usage(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
provider: LanguageModelProvider,
|
||||
model_name: &str,
|
||||
input_token_count: usize,
|
||||
output_token_count: usize,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Usage> {
|
||||
self.transaction(|tx| async move {
|
||||
let model = self.model(provider, model_name)?;
|
||||
|
||||
let usages = usage::Entity::find()
|
||||
let existing_usage = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(usage::Column::ModelId.eq(model.id)),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let requests_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::RequestsPerMinute,
|
||||
now,
|
||||
1,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let tokens_this_minute = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::TokensPerMinute,
|
||||
now,
|
||||
input_token_count + output_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let tokens_this_day = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::TokensPerDay,
|
||||
now,
|
||||
input_token_count + output_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let input_tokens_this_month = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::InputTokensPerMonth,
|
||||
now,
|
||||
input_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_this_month = self
|
||||
.update_usage_for_measure(
|
||||
user_id,
|
||||
is_staff,
|
||||
model.id,
|
||||
&usages,
|
||||
UsageMeasure::OutputTokensPerMonth,
|
||||
now,
|
||||
output_token_count,
|
||||
&tx,
|
||||
)
|
||||
.await?;
|
||||
let spending_this_month =
|
||||
calculate_spending(model, input_tokens_this_month, output_tokens_this_month);
|
||||
|
||||
// Update lifetime usage
|
||||
let lifetime_usage = lifetime_usage::Entity::find()
|
||||
.filter(
|
||||
lifetime_usage::Column::UserId
|
||||
.eq(user_id)
|
||||
.and(lifetime_usage::Column::ModelId.eq(model.id)),
|
||||
.and(usage::Column::ModelId.eq(model_id)),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
if let Some(usage) = existing_usage {
|
||||
return Ok(usage);
|
||||
}
|
||||
|
||||
let lifetime_usage = match lifetime_usage {
|
||||
Some(usage) => {
|
||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||
id: ActiveValue::unchanged(usage.id),
|
||||
input_tokens: ActiveValue::set(
|
||||
usage.input_tokens + input_token_count as i64,
|
||||
),
|
||||
output_tokens: ActiveValue::set(
|
||||
usage.output_tokens + output_token_count as i64,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
lifetime_usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model.id),
|
||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&*tx)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let lifetime_spending = calculate_spending(
|
||||
model,
|
||||
lifetime_usage.input_tokens as usize,
|
||||
lifetime_usage.output_tokens as usize,
|
||||
);
|
||||
|
||||
Ok(Usage {
|
||||
requests_this_minute,
|
||||
tokens_this_minute,
|
||||
tokens_this_day,
|
||||
input_tokens_this_month,
|
||||
output_tokens_this_month,
|
||||
spending_this_month,
|
||||
lifetime_spending,
|
||||
let usage = usage::Entity::insert(usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
model_id: ActiveValue::set(model_id),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(usage)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_user_count(&self, now: DateTimeUtc) -> Result<ActiveUserCount> {
|
||||
self.transaction(|tx| async move {
|
||||
let minute_since = now - Duration::minutes(5);
|
||||
let day_since = now - Duration::days(5);
|
||||
|
||||
let users_in_recent_minutes = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::Timestamp
|
||||
.gte(minute_since.naive_utc())
|
||||
.and(usage::Column::IsStaff.eq(false)),
|
||||
)
|
||||
.select_only()
|
||||
.column(usage::Column::UserId)
|
||||
.group_by(usage::Column::UserId)
|
||||
.count(&*tx)
|
||||
.await? as usize;
|
||||
|
||||
let users_in_recent_days = usage::Entity::find()
|
||||
.filter(
|
||||
usage::Column::Timestamp
|
||||
.gte(day_since.naive_utc())
|
||||
.and(usage::Column::IsStaff.eq(false)),
|
||||
)
|
||||
.select_only()
|
||||
.column(usage::Column::UserId)
|
||||
.group_by(usage::Column::UserId)
|
||||
.count(&*tx)
|
||||
.await? as usize;
|
||||
|
||||
Ok(ActiveUserCount {
|
||||
users_in_recent_minutes,
|
||||
users_in_recent_days,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn update_usage_for_measure(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
model_id: ModelId,
|
||||
usages: &[usage::Model],
|
||||
usage_measure: UsageMeasure,
|
||||
now: DateTimeUtc,
|
||||
usage_to_add: usize,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<usize> {
|
||||
let now = now.naive_utc();
|
||||
let measure_id = *self
|
||||
.usage_measure_ids
|
||||
.get(&usage_measure)
|
||||
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||
|
||||
let mut id = None;
|
||||
let mut timestamp = now;
|
||||
let mut buckets = vec![0_i64];
|
||||
|
||||
if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) {
|
||||
id = Some(old_usage.id);
|
||||
let (live_buckets, buckets_since) =
|
||||
Self::get_live_buckets(old_usage, now, usage_measure);
|
||||
if !live_buckets.is_empty() {
|
||||
buckets.clear();
|
||||
buckets.extend_from_slice(live_buckets);
|
||||
buckets.extend(iter::repeat(0).take(buckets_since));
|
||||
timestamp =
|
||||
old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32);
|
||||
}
|
||||
}
|
||||
|
||||
*buckets.last_mut().unwrap() += usage_to_add as i64;
|
||||
let total_usage = buckets.iter().sum::<i64>() as usize;
|
||||
|
||||
let mut model = usage::ActiveModel {
|
||||
user_id: ActiveValue::set(user_id),
|
||||
is_staff: ActiveValue::set(is_staff),
|
||||
model_id: ActiveValue::set(model_id),
|
||||
measure_id: ActiveValue::set(measure_id),
|
||||
timestamp: ActiveValue::set(timestamp),
|
||||
buckets: ActiveValue::set(buckets),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(id) = id {
|
||||
model.id = ActiveValue::unchanged(id);
|
||||
model.update(tx).await?;
|
||||
} else {
|
||||
usage::Entity::insert(model)
|
||||
.exec_without_returning(tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(total_usage)
|
||||
}
|
||||
|
||||
fn get_usage_for_measure(
|
||||
&self,
|
||||
usages: &[usage::Model],
|
||||
now: DateTimeUtc,
|
||||
usage_measure: UsageMeasure,
|
||||
) -> Result<usize> {
|
||||
let now = now.naive_utc();
|
||||
let measure_id = *self
|
||||
.usage_measure_ids
|
||||
.get(&usage_measure)
|
||||
.ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?;
|
||||
let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else {
|
||||
return Ok(0);
|
||||
};
|
||||
|
||||
let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure);
|
||||
Ok(live_buckets.iter().sum::<i64>() as _)
|
||||
}
|
||||
|
||||
fn get_live_buckets(
|
||||
usage: &usage::Model,
|
||||
now: chrono::NaiveDateTime,
|
||||
measure: UsageMeasure,
|
||||
) -> (&[i64], usize) {
|
||||
let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0);
|
||||
let buckets_since_usage =
|
||||
seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32;
|
||||
let buckets_since_usage = buckets_since_usage.ceil() as usize;
|
||||
let mut live_buckets = &[] as &[i64];
|
||||
if buckets_since_usage < measure.bucket_count() {
|
||||
let expired_bucket_count =
|
||||
(usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count());
|
||||
live_buckets = &usage.buckets[expired_bucket_count..];
|
||||
while live_buckets.first() == Some(&0) {
|
||||
live_buckets = &live_buckets[1..];
|
||||
}
|
||||
}
|
||||
(live_buckets, buckets_since_usage)
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_spending(
|
||||
model: &model::Model,
|
||||
input_tokens_this_month: usize,
|
||||
output_tokens_this_month: usize,
|
||||
) -> usize {
|
||||
let input_token_cost =
|
||||
input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000;
|
||||
let output_token_cost =
|
||||
output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000;
|
||||
input_token_cost + output_token_cost
|
||||
}
|
||||
|
||||
const MINUTE_BUCKET_COUNT: usize = 12;
|
||||
const DAY_BUCKET_COUNT: usize = 48;
|
||||
const MONTH_BUCKET_COUNT: usize = 30;
|
||||
|
||||
impl UsageMeasure {
|
||||
fn bucket_count(&self) -> usize {
|
||||
match self {
|
||||
UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerMinute => MINUTE_BUCKET_COUNT,
|
||||
UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT,
|
||||
UsageMeasure::InputTokensPerMonth => MONTH_BUCKET_COUNT,
|
||||
UsageMeasure::OutputTokensPerMonth => MONTH_BUCKET_COUNT,
|
||||
}
|
||||
}
|
||||
|
||||
fn total_duration(&self) -> Duration {
|
||||
match self {
|
||||
UsageMeasure::RequestsPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerMinute => Duration::minutes(1),
|
||||
UsageMeasure::TokensPerDay => Duration::hours(24),
|
||||
UsageMeasure::InputTokensPerMonth => Duration::days(30),
|
||||
UsageMeasure::OutputTokensPerMonth => Duration::days(30),
|
||||
}
|
||||
}
|
||||
|
||||
fn bucket_duration(&self) -> Duration {
|
||||
self.total_duration() / self.bucket_count() as i32
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
use super::*;
|
||||
use crate::{Config, Result};
|
||||
use queries::providers::ModelParams;
|
||||
|
||||
pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
|
||||
db.insert_models(&[
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-5-sonnet".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 20_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 300, // $3.00/MTok
|
||||
price_per_million_output_tokens: 1500, // $15.00/MTok
|
||||
},
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-opus".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 1500, // $15.00/MTok
|
||||
price_per_million_output_tokens: 7500, // $75.00/MTok
|
||||
},
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-sonnet".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 20_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 1500, // $15.00/MTok
|
||||
price_per_million_output_tokens: 7500, // $75.00/MTok
|
||||
},
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-haiku".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 25_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 25, // $0.25/MTok
|
||||
price_per_million_output_tokens: 125, // $1.25/MTok
|
||||
},
|
||||
])
|
||||
.await
|
||||
}
|
||||
@@ -1,6 +1,3 @@
|
||||
pub mod lifetime_usage;
|
||||
pub mod model;
|
||||
pub mod provider;
|
||||
pub mod revoked_access_token;
|
||||
pub mod usage;
|
||||
pub mod usage_measure;
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
use crate::{db::UserId, llm::db::ModelId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "lifetime_usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -10,11 +10,6 @@ pub struct Model {
|
||||
pub id: ModelId,
|
||||
pub provider_id: ProviderId,
|
||||
pub name: String,
|
||||
pub max_requests_per_minute: i64,
|
||||
pub max_tokens_per_minute: i64,
|
||||
pub max_tokens_per_day: i64,
|
||||
pub price_per_million_input_tokens: i32,
|
||||
pub price_per_million_output_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::llm::db::ProviderId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::ProviderId;
|
||||
|
||||
/// An LLM provider.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "providers")]
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
use chrono::NaiveDateTime;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::RevokedAccessTokenId;
|
||||
|
||||
/// A revoked access token.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "revoked_access_tokens")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: RevokedAccessTokenId,
|
||||
pub jti: String,
|
||||
pub revoked_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,24 +1,24 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{ModelId, UsageId, UsageMeasureId},
|
||||
};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::ModelId;
|
||||
|
||||
/// An LLM usage record.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: UsageId,
|
||||
pub id: i32,
|
||||
/// The ID of the Zed user.
|
||||
///
|
||||
/// Corresponds to the `users` table in the primary collab database.
|
||||
pub user_id: UserId,
|
||||
pub user_id: i32,
|
||||
pub model_id: ModelId,
|
||||
pub measure_id: UsageMeasureId,
|
||||
pub timestamp: DateTime,
|
||||
pub buckets: Vec<i64>,
|
||||
pub is_staff: bool,
|
||||
pub requests_this_minute: i32,
|
||||
pub tokens_this_minute: i64,
|
||||
pub requests_this_day: i32,
|
||||
pub tokens_this_day: i64,
|
||||
pub requests_this_month: i32,
|
||||
pub tokens_this_month: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
@@ -29,12 +29,6 @@ pub enum Relation {
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
#[sea_orm(
|
||||
belongs_to = "super::usage_measure::Entity",
|
||||
from = "Column::MeasureId",
|
||||
to = "super::usage_measure::Column::Id"
|
||||
)]
|
||||
UsageMeasure,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
@@ -43,10 +37,4 @@ impl Related<super::model::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::usage_measure::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::UsageMeasure.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
use crate::llm::db::UsageMeasureId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(
|
||||
Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
|
||||
)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum UsageMeasure {
|
||||
RequestsPerMinute,
|
||||
TokensPerMinute,
|
||||
TokensPerDay,
|
||||
InputTokensPerMonth,
|
||||
OutputTokensPerMonth,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "usage_measures")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: UsageMeasureId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
}
|
||||
|
||||
impl Related<super::usage::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Usages.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -6,6 +6,7 @@ use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use sea_orm::ConnectionTrait;
|
||||
use sqlx::migrate::MigrateDatabase;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
@@ -13,11 +14,47 @@ use crate::migrations::run_database_migrations;
|
||||
use super::*;
|
||||
|
||||
pub struct TestLlmDb {
|
||||
pub db: Option<LlmDatabase>,
|
||||
pub db: Option<Arc<LlmDatabase>>,
|
||||
pub connection: Option<sqlx::AnyConnection>,
|
||||
}
|
||||
|
||||
impl TestLlmDb {
|
||||
pub fn sqlite(background: BackgroundExecutor) -> Self {
|
||||
let url = "sqlite::memory:";
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut db = runtime.block_on(async {
|
||||
let mut options = ConnectOptions::new(url);
|
||||
options.max_connections(5);
|
||||
let db = LlmDatabase::new(options, Executor::Deterministic(background))
|
||||
.await
|
||||
.unwrap();
|
||||
let sql = include_str!(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/migrations_llm.sqlite/20240806182921_test_schema.sql"
|
||||
));
|
||||
db.pool
|
||||
.execute(sea_orm::Statement::from_string(
|
||||
db.pool.get_database_backend(),
|
||||
sql,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
db
|
||||
});
|
||||
|
||||
db.runtime = Some(runtime);
|
||||
|
||||
Self {
|
||||
db: Some(Arc::new(db)),
|
||||
connection: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn postgres(background: BackgroundExecutor) -> Self {
|
||||
static LOCK: Mutex<()> = Mutex::new(());
|
||||
|
||||
@@ -54,26 +91,29 @@ impl TestLlmDb {
|
||||
db.runtime = Some(runtime);
|
||||
|
||||
Self {
|
||||
db: Some(db),
|
||||
db: Some(Arc::new(db)),
|
||||
connection: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn db(&mut self) -> &mut LlmDatabase {
|
||||
self.db.as_mut().unwrap()
|
||||
pub fn db(&self) -> &Arc<LlmDatabase> {
|
||||
self.db.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! test_llm_db {
|
||||
($test_name:ident, $postgres_test_name:ident) => {
|
||||
macro_rules! test_both_llm_dbs {
|
||||
($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
|
||||
#[cfg(target_os = "macos")]
|
||||
#[gpui::test]
|
||||
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
|
||||
if !cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
|
||||
$test_name(test_db.db()).await;
|
||||
}
|
||||
|
||||
let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
|
||||
#[gpui::test]
|
||||
async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
|
||||
let test_db = $crate::llm::db::TestLlmDb::sqlite(cx.executor().clone());
|
||||
$test_name(test_db.db()).await;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::test_llm_db;
|
||||
use crate::test_both_llm_dbs;
|
||||
|
||||
test_llm_db!(
|
||||
test_both_llm_dbs!(
|
||||
test_initialize_providers,
|
||||
test_initialize_providers_postgres
|
||||
test_initialize_providers_postgres,
|
||||
test_initialize_providers_sqlite
|
||||
);
|
||||
|
||||
async fn test_initialize_providers(db: &mut LlmDatabase) {
|
||||
async fn test_initialize_providers(db: &Arc<LlmDatabase>) {
|
||||
let initial_providers = db.list_providers().await.unwrap();
|
||||
assert_eq!(initial_providers, vec![]);
|
||||
|
||||
@@ -20,13 +22,9 @@ async fn test_initialize_providers(db: &mut LlmDatabase) {
|
||||
|
||||
let providers = db.list_providers().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
providers,
|
||||
&[
|
||||
LanguageModelProvider::Anthropic,
|
||||
LanguageModelProvider::Google,
|
||||
LanguageModelProvider::OpenAi,
|
||||
LanguageModelProvider::Zed
|
||||
]
|
||||
)
|
||||
let provider_names = providers
|
||||
.into_iter()
|
||||
.map(|provider| provider.name)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(provider_names, vec!["anthropic".to_string()]);
|
||||
}
|
||||
|
||||
@@ -1,142 +1,24 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{
|
||||
queries::{providers::ModelParams, usages::Usage},
|
||||
LlmDatabase,
|
||||
},
|
||||
test_llm_db,
|
||||
};
|
||||
use chrono::{Duration, Utc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::LanguageModelProvider;
|
||||
|
||||
test_llm_db!(test_tracking_usage, test_tracking_usage_postgres);
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::test_both_llm_dbs;
|
||||
|
||||
async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||
let provider = LanguageModelProvider::Anthropic;
|
||||
let model = "claude-3-5-sonnet";
|
||||
test_both_llm_dbs!(
|
||||
test_find_or_create_usage,
|
||||
test_find_or_create_usage_postgres,
|
||||
test_find_or_create_usage_sqlite
|
||||
);
|
||||
|
||||
db.initialize().await.unwrap();
|
||||
db.insert_models(&[ModelParams {
|
||||
provider,
|
||||
name: model.to_string(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 50_000,
|
||||
price_per_million_input_tokens: 50,
|
||||
price_per_million_output_tokens: 50,
|
||||
}])
|
||||
.await
|
||||
.unwrap();
|
||||
async fn test_find_or_create_usage(db: &Arc<LlmDatabase>) {
|
||||
db.initialize_providers().await.unwrap();
|
||||
|
||||
let t0 = Utc::now();
|
||||
let user_id = UserId::from_proto(123);
|
||||
|
||||
let now = t0;
|
||||
db.record_usage(user_id, false, provider, model, 1000, 0, now)
|
||||
let usage = db
|
||||
.find_or_create_usage(123, LanguageModelProvider::Anthropic, "claude-3-5-sonnet")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let now = t0 + Duration::seconds(10);
|
||||
db.record_usage(user_id, false, provider, model, 2000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 3000,
|
||||
tokens_this_day: 3000,
|
||||
input_tokens_this_month: 3000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 2000,
|
||||
tokens_this_day: 3000,
|
||||
input_tokens_this_month: 3000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
let now = t0 + Duration::seconds(60);
|
||||
db.record_usage(user_id, false, provider, model, 3000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 2,
|
||||
tokens_this_minute: 5000,
|
||||
tokens_this_day: 6000,
|
||||
input_tokens_this_month: 6000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
let t1 = t0 + Duration::hours(24);
|
||||
let now = t1;
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
tokens_this_day: 5000,
|
||||
input_tokens_this_month: 6000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
db.record_usage(user_id, false, provider, model, 4000, 0, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 1,
|
||||
tokens_this_minute: 4000,
|
||||
tokens_this_day: 9000,
|
||||
input_tokens_this_month: 10000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
|
||||
let t2 = t0 + Duration::days(30);
|
||||
let now = t2;
|
||||
let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
|
||||
assert_eq!(
|
||||
usage,
|
||||
Usage {
|
||||
requests_this_minute: 0,
|
||||
tokens_this_minute: 0,
|
||||
tokens_this_day: 0,
|
||||
input_tokens_this_month: 9000,
|
||||
output_tokens_this_month: 0,
|
||||
spending_this_month: 0,
|
||||
lifetime_spending: 0,
|
||||
}
|
||||
);
|
||||
assert_eq!(usage.user_id, 123);
|
||||
}
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct LlmUsageEventRow {
|
||||
pub time: i64,
|
||||
pub user_id: i32,
|
||||
pub is_staff: bool,
|
||||
pub plan: String,
|
||||
pub model: String,
|
||||
pub provider: String,
|
||||
pub input_token_count: u64,
|
||||
pub output_token_count: u64,
|
||||
pub requests_this_minute: u64,
|
||||
pub tokens_this_minute: u64,
|
||||
pub tokens_this_day: u64,
|
||||
pub input_tokens_this_month: u64,
|
||||
pub output_tokens_this_month: u64,
|
||||
pub spending_this_month: u64,
|
||||
pub lifetime_spending: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct LlmRateLimitEventRow {
|
||||
pub time: i64,
|
||||
pub user_id: i32,
|
||||
pub is_staff: bool,
|
||||
pub plan: String,
|
||||
pub model: String,
|
||||
pub provider: String,
|
||||
pub usage_measure: String,
|
||||
pub requests_this_minute: u64,
|
||||
pub tokens_this_minute: u64,
|
||||
pub tokens_this_day: u64,
|
||||
pub users_in_recent_minutes: u64,
|
||||
pub users_in_recent_days: u64,
|
||||
pub max_requests_per_minute: u64,
|
||||
pub max_tokens_per_minute: u64,
|
||||
pub max_tokens_per_day: u64,
|
||||
}
|
||||
|
||||
pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> {
|
||||
let mut insert = client.insert("llm_usage_events")?;
|
||||
insert.write(&row).await?;
|
||||
insert.end().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn report_llm_rate_limit(
|
||||
client: &clickhouse::Client,
|
||||
row: LlmRateLimitEventRow,
|
||||
) -> Result<()> {
|
||||
let mut insert = client.insert("llm_rate_limits")?;
|
||||
insert.write(&row).await?;
|
||||
insert.end().await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -52,18 +52,10 @@ async fn main() -> Result<()> {
|
||||
Some("seed") => {
|
||||
let config = envy::from_env::<Config>().expect("error loading config");
|
||||
let db_options = db::ConnectOptions::new(config.database_url.clone());
|
||||
|
||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
collab::seed::seed(&config, &db, false).await?;
|
||||
|
||||
if let Some(llm_database_url) = config.llm_database_url.clone() {
|
||||
let db_options = db::ConnectOptions::new(llm_database_url);
|
||||
let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
|
||||
db.initialize().await?;
|
||||
collab::llm::db::seed_database(&config, &mut db, true).await?;
|
||||
}
|
||||
collab::seed::seed(&config, &db, true).await?;
|
||||
}
|
||||
Some("serve") => {
|
||||
let mode = match args.next().as_deref() {
|
||||
@@ -150,7 +142,6 @@ async fn main() -> Result<()> {
|
||||
"http_request",
|
||||
method = ?request.method(),
|
||||
matched_path,
|
||||
authn.jti = tracing::field::Empty
|
||||
)
|
||||
})
|
||||
.on_response(
|
||||
@@ -249,6 +240,11 @@ async fn setup_app_database(config: &Config) -> Result<()> {
|
||||
}
|
||||
|
||||
async fn setup_llm_database(config: &Config) -> Result<()> {
|
||||
// TODO: This is temporary until we have the LLM database stood up.
|
||||
if !config.is_development() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let database_url = config
|
||||
.llm_database_url
|
||||
.as_ref()
|
||||
@@ -283,7 +279,10 @@ async fn setup_llm_database(config: &Config) -> Result<()> {
|
||||
}
|
||||
|
||||
async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
|
||||
format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown"))
|
||||
format!(
|
||||
"collab {mode:?} v{VERSION} ({})",
|
||||
REVISION.unwrap_or("unknown")
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_liveness_probe(
|
||||
@@ -294,12 +293,7 @@ async fn handle_liveness_probe(
|
||||
state.db.get_all_users(0, 1).await?;
|
||||
}
|
||||
|
||||
if let Some(llm_state) = llm_state {
|
||||
llm_state
|
||||
.db
|
||||
.get_active_user_count(chrono::Utc::now())
|
||||
.await?;
|
||||
}
|
||||
if let Some(_llm_state) = llm_state {}
|
||||
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ use axum::{
|
||||
routing::get,
|
||||
Extension, Router, TypedHeader,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
@@ -106,6 +105,18 @@ impl<R: RequestMessage> Response<R> {
|
||||
}
|
||||
}
|
||||
|
||||
struct StreamingResponse<R: RequestMessage> {
|
||||
peer: Arc<Peer>,
|
||||
receipt: Receipt<R>,
|
||||
}
|
||||
|
||||
impl<R: RequestMessage> StreamingResponse<R> {
|
||||
fn send(&self, payload: R::Response) -> Result<()> {
|
||||
self.peer.respond(self.receipt, payload)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Principal {
|
||||
User(User),
|
||||
@@ -605,7 +616,6 @@ impl Server {
|
||||
.add_message_handler(user_message_handler(update_followers))
|
||||
.add_request_handler(user_handler(get_private_user_info))
|
||||
.add_request_handler(user_handler(get_llm_api_token))
|
||||
.add_request_handler(user_handler(accept_terms_of_service))
|
||||
.add_message_handler(user_message_handler(acknowledge_channel_message))
|
||||
.add_message_handler(user_message_handler(acknowledge_buffer_version))
|
||||
.add_request_handler(user_handler(get_supermaven_api_key))
|
||||
@@ -620,6 +630,31 @@ impl Server {
|
||||
))
|
||||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||
.add_message_handler(update_context)
|
||||
.add_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
let app_state = app_state.clone();
|
||||
async move {
|
||||
complete_with_language_model(request, response, session, &app_state.config)
|
||||
.await
|
||||
}
|
||||
}
|
||||
})
|
||||
.add_streaming_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
let app_state = app_state.clone();
|
||||
async move {
|
||||
stream_complete_with_language_model(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
&app_state.config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
})
|
||||
.add_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
@@ -913,6 +948,40 @@ impl Server {
|
||||
})
|
||||
}
|
||||
|
||||
fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
|
||||
Fut: Send + Future<Output = Result<()>>,
|
||||
M: RequestMessage,
|
||||
{
|
||||
let handler = Arc::new(handler);
|
||||
self.add_handler(move |envelope, session| {
|
||||
let receipt = envelope.receipt();
|
||||
let handler = handler.clone();
|
||||
async move {
|
||||
let peer = session.peer.clone();
|
||||
let response = StreamingResponse {
|
||||
peer: peer.clone(),
|
||||
receipt,
|
||||
};
|
||||
match (handler)(envelope.payload, response, session).await {
|
||||
Ok(()) => {
|
||||
peer.end_stream(receipt)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let proto_err = match &error {
|
||||
Error::Internal(err) => err.to_proto(),
|
||||
_ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
|
||||
};
|
||||
peer.respond_with_error(receipt, proto_err)?;
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn handle_connection(
|
||||
self: &Arc<Self>,
|
||||
@@ -4492,6 +4561,202 @@ async fn acknowledge_buffer_version(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct ZedProCompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(120) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"zed-pro:complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
struct FreeCompleteWithLanguageModelRateLimit;
|
||||
|
||||
impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
|
||||
fn capacity(&self) -> usize {
|
||||
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(120 / 10) // Picked arbitrarily
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> chrono::Duration {
|
||||
chrono::Duration::hours(1)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"free:complete-with-language-model"
|
||||
}
|
||||
}
|
||||
|
||||
async fn complete_with_language_model(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: Response<proto::CompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
||||
let api_key = config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?;
|
||||
anthropic::complete(
|
||||
session.http_client.as_ref(),
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
_ => return Err(anyhow!("unsupported provider"))?,
|
||||
};
|
||||
|
||||
response.send(proto::CompleteWithLanguageModelResponse {
|
||||
completion: serde_json::to_string(&result)?,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stream_complete_with_language_model(
|
||||
request: proto::StreamCompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
|
||||
proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
|
||||
};
|
||||
|
||||
session
|
||||
.app_state
|
||||
.rate_limiter
|
||||
.check(&*rate_limit, session.user_id())
|
||||
.await?;
|
||||
|
||||
match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
||||
let api_key = config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?;
|
||||
let mut chunks = anthropic::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = chunks.next().await {
|
||||
let chunk = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&chunk)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Some(proto::LanguageModelProvider::OpenAi) => {
|
||||
let api_key = config
|
||||
.openai_api_key
|
||||
.as_ref()
|
||||
.context("no OpenAI API key configured on the server")?;
|
||||
let mut events = open_ai::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
open_ai::OPEN_AI_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&event)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Some(proto::LanguageModelProvider::Google) => {
|
||||
let api_key = config
|
||||
.google_ai_api_key
|
||||
.as_ref()
|
||||
.context("no Google AI API key configured on the server")?;
|
||||
let mut events = google_ai::stream_generate_content(
|
||||
session.http_client.as_ref(),
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&event)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Some(proto::LanguageModelProvider::Zed) => {
|
||||
let api_key = config
|
||||
.qwen2_7b_api_key
|
||||
.as_ref()
|
||||
.context("no Qwen2-7B API key configured on the server")?;
|
||||
let api_url = config
|
||||
.qwen2_7b_api_url
|
||||
.as_ref()
|
||||
.context("no Qwen2-7B URL configured on the server")?;
|
||||
let mut events = open_ai::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
&api_url,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&event)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
None => return Err(anyhow!("unknown provider"))?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn count_language_model_tokens(
|
||||
request: proto::CountLanguageModelTokens,
|
||||
response: Response<proto::CountLanguageModelTokens>,
|
||||
@@ -4501,7 +4766,7 @@ async fn count_language_model_tokens(
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
|
||||
@@ -4621,7 +4886,7 @@ async fn compute_embeddings(
|
||||
api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
let api_key = api_key.context("no OpenAI API key configured on the server")?;
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
|
||||
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
|
||||
@@ -4685,7 +4950,7 @@ async fn get_cached_embeddings(
|
||||
response: Response<proto::GetCachedEmbeddings>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
authorize_access_to_legacy_llm_endpoints(&session).await?;
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
let db = session.db().await;
|
||||
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
|
||||
@@ -4699,15 +4964,14 @@ async fn get_cached_embeddings(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This is leftover from before the LLM service.
|
||||
///
|
||||
/// The endpoints protected by this check will be moved there eventually.
|
||||
async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
|
||||
if session.is_staff() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("permission denied"))?
|
||||
async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
|
||||
let db = session.db().await;
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if flags.iter().any(|flag| flag == "language-models") {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
/// Get a Supermaven API key for the user
|
||||
@@ -4885,64 +5149,21 @@ async fn get_private_user_info(
|
||||
metrics_id,
|
||||
staff: user.admin,
|
||||
flags,
|
||||
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Accept the terms of service (tos) on behalf of the current user
|
||||
async fn accept_terms_of_service(
|
||||
_request: proto::AcceptTermsOfService,
|
||||
response: Response<proto::AcceptTermsOfService>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let accepted_tos_at = Utc::now();
|
||||
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
|
||||
.await?;
|
||||
|
||||
response.send(proto::AcceptTermsOfServiceResponse {
|
||||
accepted_tos_at: accepted_tos_at.timestamp() as u64,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The minimum account age an account must have in order to use the LLM service.
|
||||
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
|
||||
|
||||
async fn get_llm_api_token(
|
||||
_request: proto::GetLlmToken,
|
||||
response: Response<proto::GetLlmToken>,
|
||||
session: UserSession,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let flags = db.get_user_flags(session.user_id()).await?;
|
||||
if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
|
||||
if !session.is_staff() {
|
||||
Err(anyhow!("permission denied"))?
|
||||
}
|
||||
|
||||
let user_id = session.user_id();
|
||||
let user = db
|
||||
.get_user_by_id(user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user {} not found", user_id))?;
|
||||
|
||||
if user.accepted_tos_at.is_none() {
|
||||
Err(anyhow!("terms of service not accepted"))?
|
||||
}
|
||||
|
||||
let mut account_created_at = user.created_at;
|
||||
if let Some(github_created_at) = user.github_user_created_at {
|
||||
account_created_at = account_created_at.min(github_created_at);
|
||||
}
|
||||
if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
|
||||
Err(anyhow!("account too young"))?
|
||||
}
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
user.id,
|
||||
session.user_id(),
|
||||
session.is_staff(),
|
||||
session.current_plan().await?,
|
||||
&session.app_state.config,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::db::{self, ChannelRole, NewUserParams};
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Utc};
|
||||
use db::Database;
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use std::{fmt::Write, fs, path::Path};
|
||||
@@ -13,7 +12,6 @@ struct GitHubUser {
|
||||
id: i32,
|
||||
login: String,
|
||||
email: Option<String>,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -109,7 +107,6 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
|
||||
&github_user.login,
|
||||
Some(github_user.id),
|
||||
github_user.email.as_deref(),
|
||||
Some(github_user.created_at),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
|
||||