Compare commits
104 Commits
debugger-i
...
suggest-ne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33d2c8f726 | ||
|
|
1e160f22ce | ||
|
|
fdb6f90abb | ||
|
|
d473c6892d | ||
|
|
a036b64fb9 | ||
|
|
5fa3a8256c | ||
|
|
6a9951769d | ||
|
|
771a7bb4b6 | ||
|
|
49f87165ea | ||
|
|
66ada3e44c | ||
|
|
d6bdaa8a91 | ||
|
|
e0fe7f632c | ||
|
|
fee49fcf65 | ||
|
|
58755a6c88 | ||
|
|
04e25525bf | ||
|
|
3a44a59f8e | ||
|
|
53a3d6424f | ||
|
|
0171b7d53c | ||
|
|
26aec4ba99 | ||
|
|
cb07e02ce9 | ||
|
|
b05d532991 | ||
|
|
f88278111e | ||
|
|
f1d777434b | ||
|
|
c7a78fafac | ||
|
|
b8982ad385 | ||
|
|
1ffb34c5fc | ||
|
|
8b22f09b6f | ||
|
|
e72f33d79e | ||
|
|
af24967195 | ||
|
|
acea6f9c0f | ||
|
|
4976a9e9d8 | ||
|
|
138c3fcfdd | ||
|
|
64add2f222 | ||
|
|
70e895a8c7 | ||
|
|
4bd935b409 | ||
|
|
c0df1e1846 | ||
|
|
e9d0768e3c | ||
|
|
380a99038b | ||
|
|
88653c4e3e | ||
|
|
3751f67730 | ||
|
|
6af385c09e | ||
|
|
e6cd1cf22b | ||
|
|
a1bd7a1297 | ||
|
|
3e31955b7f | ||
|
|
be86852f95 | ||
|
|
bde02a350e | ||
|
|
4c9311ba40 | ||
|
|
c8bc49fa18 | ||
|
|
bcd972fbb4 | ||
|
|
e423f03ba6 | ||
|
|
03ebbcbef6 | ||
|
|
27f97ba762 | ||
|
|
769ae8b101 | ||
|
|
d27fef7b2c | ||
|
|
f4bbbe69b4 | ||
|
|
c937a2fcdd | ||
|
|
a5279cc48a | ||
|
|
4d56252bae | ||
|
|
0360cda543 | ||
|
|
5e04753d1c | ||
|
|
71312e5692 | ||
|
|
05825e9804 | ||
|
|
73d682c010 | ||
|
|
e59e47fe7f | ||
|
|
4abf7f058e | ||
|
|
f980e40993 | ||
|
|
57b2cb6f60 | ||
|
|
af014a2530 | ||
|
|
243fb3562c | ||
|
|
e830865eb1 | ||
|
|
7aa6f4788d | ||
|
|
18daf17d0e | ||
|
|
856d9632e4 | ||
|
|
745d2e4d3b | ||
|
|
50dbab0747 | ||
|
|
70c22cbdd6 | ||
|
|
9621005851 | ||
|
|
05003ed4c5 | ||
|
|
2c610c0e57 | ||
|
|
479ffbbd51 | ||
|
|
fe23504eba | ||
|
|
95d82f88de | ||
|
|
4000b0a02c | ||
|
|
02c43a5bf2 | ||
|
|
f2060ccbe0 | ||
|
|
13693ff80f | ||
|
|
ec5886a078 | ||
|
|
10c9e337cf | ||
|
|
1da6a12bb4 | ||
|
|
cc1d3f0a35 | ||
|
|
22118f15e9 | ||
|
|
0d5de88c4b | ||
|
|
f291677d40 | ||
|
|
9d736fe80c | ||
|
|
f3ad754396 | ||
|
|
86456ce379 | ||
|
|
d755d29577 | ||
|
|
ab3c9f0678 | ||
|
|
201db23b58 | ||
|
|
beb8fbdf7f | ||
|
|
d2501e8886 | ||
|
|
82d6ad4616 | ||
|
|
a60b3b9389 | ||
|
|
06863144c6 |
@@ -12,3 +12,7 @@ rustflags = ["-C", "link-arg=-fuse-ld=mold"]
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "clang"
|
||||
rustflags = ["-C", "link-arg=-fuse-ld=mold"]
|
||||
|
||||
# This cfg will reduce the size of `windows::core::Error` from 16 bytes to 4 bytes
|
||||
[target.'cfg(target_os = "windows")']
|
||||
rustflags = ["--cfg", "windows_slim_errors"]
|
||||
|
||||
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@@ -147,7 +147,8 @@ jobs:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: cargo clippy
|
||||
run: ./script/clippy
|
||||
# Windows can't run shell scripts, so we need to use `cargo xtask`.
|
||||
run: cargo xtask clippy
|
||||
|
||||
- name: Build Zed
|
||||
run: cargo build -p zed
|
||||
|
||||
2
.github/workflows/danger.yml
vendored
2
.github/workflows/danger.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- uses: pnpm/action-setup@v3
|
||||
- uses: pnpm/action-setup@a3252b78c470c02df07e9d59298aecedc3ccdd6d # v3
|
||||
with:
|
||||
version: 9
|
||||
|
||||
|
||||
24
.github/workflows/docs.yml
vendored
Normal file
24
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Docs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "docs/**"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
check_formatting:
|
||||
name: "Check formatting"
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4
|
||||
|
||||
- uses: pnpm/action-setup@a3252b78c470c02df07e9d59298aecedc3ccdd6d # v3
|
||||
with:
|
||||
version: 9
|
||||
|
||||
- run: pnpm dlx prettier . --check
|
||||
working-directory: ./docs
|
||||
@@ -26,6 +26,10 @@
|
||||
"tab_size": 2,
|
||||
"formatter": "prettier"
|
||||
},
|
||||
"CSS": {
|
||||
"tab_size": 2,
|
||||
"formatter": "prettier"
|
||||
},
|
||||
"Rust": {
|
||||
"tasks": {
|
||||
"variables": {
|
||||
|
||||
1084
Cargo.lock
generated
1084
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
65
Cargo.toml
65
Cargo.toml
@@ -1,11 +1,11 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/activity_indicator",
|
||||
"crates/anthropic",
|
||||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/assistant_slash_command",
|
||||
"crates/assistant_tooling",
|
||||
"crates/audio",
|
||||
"crates/auto_update",
|
||||
"crates/breadcrumbs",
|
||||
@@ -125,6 +125,10 @@ members = [
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
|
||||
#
|
||||
# Extensions
|
||||
#
|
||||
|
||||
"extensions/astro",
|
||||
"extensions/clojure",
|
||||
"extensions/csharp",
|
||||
@@ -154,20 +158,25 @@ members = [
|
||||
"extensions/vue",
|
||||
"extensions/zig",
|
||||
|
||||
#
|
||||
# Tooling
|
||||
#
|
||||
|
||||
"tooling/xtask",
|
||||
]
|
||||
default-members = ["crates/zed"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
#
|
||||
# Workspace member crates
|
||||
#
|
||||
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
aho-corasick = "1.1"
|
||||
ai = { path = "crates/ai" }
|
||||
anthropic = { path = "crates/anthropic" }
|
||||
assets = { path = "crates/assets" }
|
||||
assistant = { path = "crates/assistant" }
|
||||
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||
assistant_tooling = { path = "crates/assistant_tooling" }
|
||||
audio = { path = "crates/audio" }
|
||||
auto_update = { path = "crates/auto_update" }
|
||||
breadcrumbs = { path = "crates/breadcrumbs" }
|
||||
@@ -240,6 +249,7 @@ project_symbols = { path = "crates/project_symbols" }
|
||||
proto = { path = "crates/proto" }
|
||||
quick_action_bar = { path = "crates/quick_action_bar" }
|
||||
recent_projects = { path = "crates/recent_projects" }
|
||||
refineable = { path = "crates/refineable" }
|
||||
release_channel = { path = "crates/release_channel" }
|
||||
remote = { path = "crates/remote" }
|
||||
remote_server = { path = "crates/remote_server" }
|
||||
@@ -285,39 +295,44 @@ worktree = { path = "crates/worktree" }
|
||||
zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
|
||||
#
|
||||
# External crates
|
||||
#
|
||||
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = "0.23"
|
||||
any_vec = "0.14"
|
||||
anyhow = "1.0.86"
|
||||
ashpd = "0.9.1"
|
||||
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
|
||||
async-dispatcher = { version = "0.1" }
|
||||
async-dispatcher = "0.1"
|
||||
async-fs = "1.6"
|
||||
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
|
||||
async-recursion = "1.0.0"
|
||||
async-tar = "0.4.2"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = { version = "0.16" }
|
||||
async-tungstenite = "0.23"
|
||||
async-watch = "0.3.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
base64 = "0.13"
|
||||
base64 = "0.22"
|
||||
bitflags = "2.6.0"
|
||||
blade-graphics = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
|
||||
blade-macros = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
|
||||
blade-util = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
|
||||
cap-std = "3.0"
|
||||
cargo_metadata = "0.18"
|
||||
cargo_toml = "0.20"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
clickhouse = { version = "0.11.6" }
|
||||
clickhouse = "0.11.6"
|
||||
cocoa = "0.25"
|
||||
core-foundation = { version = "0.9.3" }
|
||||
core-foundation = "0.9.3"
|
||||
core-foundation-sys = "0.8.6"
|
||||
ctor = "0.2.6"
|
||||
dashmap = "5.5.3"
|
||||
dashmap = "6.0"
|
||||
derive_more = "0.99.17"
|
||||
dirs = "4.0"
|
||||
emojis = "0.6.1"
|
||||
env_logger = "0.10"
|
||||
env_logger = "0.11"
|
||||
exec = "0.3.1"
|
||||
fork = "0.1.23"
|
||||
futures = "0.3"
|
||||
@@ -331,12 +346,13 @@ html5ever = "0.27.0"
|
||||
ignore = "0.4.22"
|
||||
image = "0.25.1"
|
||||
indexmap = { version = "1.6.2", features = ["serde"] }
|
||||
indoc = "1"
|
||||
indoc = "2"
|
||||
# We explicitly disable http2 support in isahc.
|
||||
isahc = { version = "1.7.2", default-features = false, features = [
|
||||
"text-decoding",
|
||||
] }
|
||||
itertools = "0.11.0"
|
||||
jsonwebtoken = "9.3"
|
||||
lazy_static = "1.4.0"
|
||||
libc = "0.2"
|
||||
linkify = "0.10.0"
|
||||
@@ -358,7 +374,6 @@ prost-build = "0.9"
|
||||
prost-types = "0.9"
|
||||
pulldown-cmark = { version = "0.10.0", default-features = false }
|
||||
rand = "0.8.5"
|
||||
refineable = { path = "./crates/refineable" }
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
rsa = "0.9.6"
|
||||
@@ -388,6 +403,7 @@ smallvec = { version = "1.6", features = ["union"] }
|
||||
smol = "1.2"
|
||||
strum = { version = "0.25.0", features = ["derive"] }
|
||||
subtle = "2.5.0"
|
||||
sys-locale = "0.3.1"
|
||||
sysinfo = "0.30.7"
|
||||
tempfile = "3.9.0"
|
||||
thiserror = "1.0.29"
|
||||
@@ -411,7 +427,7 @@ tree-sitter-css = "0.21"
|
||||
tree-sitter-elixir = "0.2"
|
||||
tree-sitter-embedded-template = "0.20.0"
|
||||
tree-sitter-go = "0.21"
|
||||
tree-sitter-go-mod = { git = "https://github.com/SomeoneToIgnore/tree-sitter-go-mod", rev = "8c1f54f12bb4c846336b634bc817645d6f35d641", package = "tree-sitter-gomod"}
|
||||
tree-sitter-go-mod = { git = "https://github.com/SomeoneToIgnore/tree-sitter-go-mod", rev = "8c1f54f12bb4c846336b634bc817645d6f35d641", package = "tree-sitter-gomod" }
|
||||
tree-sitter-gowork = { git = "https://github.com/d1y/tree-sitter-go-work", rev = "dcbabff454703c3a4bc98a23cf8778d4be46fd22" }
|
||||
tree-sitter-heex = { git = "https://github.com/phoenixframework/tree-sitter-heex", rev = "6dd0303acf7138dd2b9b432a229e16539581c701" }
|
||||
tree-sitter-html = "0.20"
|
||||
@@ -432,20 +448,19 @@ url = "2.2"
|
||||
uuid = { version = "1.1.2", features = ["v4", "v5", "serde"] }
|
||||
wasmparser = "0.201"
|
||||
wasm-encoder = "0.201"
|
||||
wasmtime = { version = "19.0.2", default-features = false, features = [
|
||||
wasmtime = { version = "21.0.1", default-features = false, features = [
|
||||
"async",
|
||||
"demangle",
|
||||
"runtime",
|
||||
"cranelift",
|
||||
"component-model",
|
||||
] }
|
||||
wasmtime-wasi = "19.0.2"
|
||||
wasmtime-wasi = "21.0.1"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.201"
|
||||
sys-locale = "0.3.1"
|
||||
|
||||
[workspace.dependencies.windows]
|
||||
version = "0.57"
|
||||
version = "0.58"
|
||||
features = [
|
||||
"implement",
|
||||
"Foundation_Numerics",
|
||||
@@ -465,7 +480,6 @@ features = [
|
||||
"Win32_Security",
|
||||
"Win32_Security_Credentials",
|
||||
"Win32_Storage_FileSystem",
|
||||
"Win32_System_LibraryLoader",
|
||||
"Win32_System_Com",
|
||||
"Win32_System_Com_StructuredStorage",
|
||||
"Win32_System_DataExchange",
|
||||
@@ -484,6 +498,10 @@ features = [
|
||||
"Win32_UI_WindowsAndMessaging",
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
# Patch Tree-sitter for updated wasmtime.
|
||||
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "7f4a57817d58a2f134fe863674acad6bbf007228" }
|
||||
|
||||
[profile.dev]
|
||||
split-debuginfo = "unpacked"
|
||||
debug = "limited"
|
||||
@@ -535,13 +553,6 @@ single_range_in_vec_init = "allow"
|
||||
style = { level = "allow", priority = -1 }
|
||||
|
||||
# Individual rules that have violations in the codebase:
|
||||
almost_complete_range = "allow"
|
||||
arc_with_non_send_sync = "allow"
|
||||
borrowed_box = "allow"
|
||||
let_underscore_future = "allow"
|
||||
map_entry = "allow"
|
||||
non_canonical_partial_ord_impl = "allow"
|
||||
reversed_empty_ranges = "allow"
|
||||
type_complexity = "allow"
|
||||
|
||||
[workspace.metadata.cargo-machete]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# syntax = docker/dockerfile:1.2
|
||||
|
||||
FROM rust:1.79-bookworm as builder
|
||||
FROM rust:1.80-bookworm as builder
|
||||
WORKDIR app
|
||||
COPY . .
|
||||
|
||||
|
||||
1
assets/icons/eye.svg
Normal file
1
assets/icons/eye.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-eye"><path d="M2.062 12.348a1 1 0 0 1 0-.696 10.75 10.75 0 0 1 19.876 0 1 1 0 0 1 0 .696 10.75 10.75 0 0 1-19.876 0"/><circle cx="12" cy="12" r="3"/></svg>
|
||||
|
After Width: | Height: | Size: 358 B |
1
assets/icons/file_code.svg
Normal file
1
assets/icons/file_code.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-file-code"><path d="M10 12.5 8 15l2 2.5"/><path d="m14 12.5 2 2.5-2 2.5"/><path d="M14 2v4a2 2 0 0 0 2 2h4"/><path d="M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7z"/></svg>
|
||||
|
After Width: | Height: | Size: 388 B |
1
assets/icons/file_text.svg
Normal file
1
assets/icons/file_text.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-file-text"><path d="M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7Z"/><path d="M14 2v4a2 2 0 0 0 2 2h4"/><path d="M10 9H8"/><path d="M16 13H8"/><path d="M16 17H8"/></svg>
|
||||
|
After Width: | Height: | Size: 384 B |
6
assets/icons/sliders-alt.svg
Normal file
6
assets/icons/sliders-alt.svg
Normal file
@@ -0,0 +1,6 @@
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3 4H8" stroke="black" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M6 10L11 10" stroke="black" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<circle cx="4" cy="10" r="1.875" stroke="black" stroke-width="1.75"/>
|
||||
<circle cx="10" cy="4" r="1.875" stroke="black" stroke-width="1.75"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 450 B |
@@ -40,7 +40,6 @@
|
||||
"backspace": "editor::Backspace",
|
||||
"shift-backspace": "editor::Backspace",
|
||||
"delete": "editor::Delete",
|
||||
"ctrl-d": "editor::Delete",
|
||||
"tab": "editor::Tab",
|
||||
"shift-tab": "editor::TabPrev",
|
||||
"ctrl-k": "editor::CutToEndOfLine",
|
||||
@@ -269,6 +268,7 @@
|
||||
"alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
|
||||
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
|
||||
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
|
||||
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }],
|
||||
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
|
||||
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }],
|
||||
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }],
|
||||
|
||||
@@ -26,6 +26,9 @@
|
||||
},
|
||||
// The name of a font to use for rendering text in the editor
|
||||
"buffer_font_family": "Zed Plex Mono",
|
||||
// Set the buffer text's font fallbacks, this will be merged with
|
||||
// the platform's default fallbacks.
|
||||
"buffer_font_fallbacks": [],
|
||||
// The OpenType features to enable for text in the editor.
|
||||
"buffer_font_features": {
|
||||
// Disable ligatures:
|
||||
@@ -47,8 +50,11 @@
|
||||
// },
|
||||
"buffer_line_height": "comfortable",
|
||||
// The name of a font to use for rendering text in the UI
|
||||
// (On macOS) You can set this to ".SystemUIFont" to use the system font
|
||||
// You can set this to ".SystemUIFont" to use the system font
|
||||
"ui_font_family": "Zed Plex Sans",
|
||||
// Set the UI's font fallbacks, this will be merged with the platform's
|
||||
// default font fallbacks.
|
||||
"ui_font_fallbacks": [],
|
||||
// The OpenType features to enable for text in the UI
|
||||
"ui_font_features": {
|
||||
// Disable ligatures:
|
||||
@@ -312,7 +318,7 @@
|
||||
"auto_reveal_entries": true,
|
||||
// Whether to fold directories automatically and show compact folders
|
||||
// (e.g. "a/b/c" ) when a directory has only one subdirectory inside.
|
||||
"auto_fold_dirs": false,
|
||||
"auto_fold_dirs": true,
|
||||
/// Scrollbar-related settings
|
||||
"scrollbar": {
|
||||
/// When to show the scrollbar in the project panel.
|
||||
@@ -675,6 +681,10 @@
|
||||
// Set the terminal's font family. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font family.
|
||||
// "font_family": "Zed Plex Mono",
|
||||
// Set the terminal's font fallbacks. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font fallbacks.
|
||||
// This will be merged with the platform's default font fallbacks
|
||||
// "font_fallbacks": ["FiraCode Nerd Fonts"],
|
||||
// Sets the maximum number of lines in the terminal's scrollback buffer.
|
||||
// Default: 10_000, maximum: 100_000 (all bigger values set will be treated as 100_000), 0 disables the scrolling.
|
||||
// Existing terminals will not pick up this change until they are recreated.
|
||||
@@ -860,6 +870,9 @@
|
||||
"openai": {
|
||||
"api_url": "https://api.openai.com/v1"
|
||||
},
|
||||
"google": {
|
||||
"api_url": "https://generativelanguage.googleapis.com"
|
||||
},
|
||||
"ollama": {
|
||||
"api_url": "http://localhost:11434"
|
||||
}
|
||||
@@ -965,5 +978,21 @@
|
||||
// {
|
||||
// "W": "workspace::Save"
|
||||
// }
|
||||
"command_aliases": {}
|
||||
"command_aliases": {},
|
||||
// ssh_connections is an array of ssh connections.
|
||||
// By default this setting is null, which disables the direct ssh connection support.
|
||||
// You can configure these from `project: Open Remote` in the command palette.
|
||||
// Zed's ssh support will pull configuration from your ~/.ssh too.
|
||||
// Examples:
|
||||
// [
|
||||
// {
|
||||
// "host": "example-box",
|
||||
// "projects": [
|
||||
// {
|
||||
// "paths": ["/home/user/code/zed"]
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// ]
|
||||
"ssh_connections": null
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{convert::TryFrom, time::Duration};
|
||||
use std::time::Duration;
|
||||
use strum::EnumIter;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||
@@ -70,120 +70,53 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
pub async fn complete(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<Response> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("Anthropic-Beta", "tools-2024-04-04")
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let response_message: Response = serde_json::from_slice(&body)?;
|
||||
Ok(response_message)
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Request {
|
||||
#[serde(serialize_with = "serialize_request_model")]
|
||||
pub model: Model,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub system: String,
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&model.id())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseEvent {
|
||||
MessageStart {
|
||||
message: ResponseMessage,
|
||||
},
|
||||
ContentBlockStart {
|
||||
index: u32,
|
||||
content_block: ContentBlock,
|
||||
},
|
||||
Ping {},
|
||||
ContentBlockDelta {
|
||||
index: u32,
|
||||
delta: TextDelta,
|
||||
},
|
||||
ContentBlockStop {
|
||||
index: u32,
|
||||
},
|
||||
MessageDelta {
|
||||
delta: ResponseMessage,
|
||||
usage: Usage,
|
||||
},
|
||||
MessageStop {},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ResponseMessage {
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: Option<String>,
|
||||
pub id: Option<String>,
|
||||
pub role: Option<String>,
|
||||
pub content: Option<Vec<String>>,
|
||||
pub model: Option<String>,
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: Option<u32>,
|
||||
pub output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
Text { text: String },
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum TextDelta {
|
||||
TextDelta { text: String },
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||
) -> Result<BoxStream<'static, Result<Event>>> {
|
||||
let request = StreamingRequest {
|
||||
base: request,
|
||||
stream: true,
|
||||
};
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
@@ -195,7 +128,9 @@ pub async fn stream_completion(
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
}
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
@@ -220,7 +155,7 @@ pub async fn stream_completion(
|
||||
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
match serde_json::from_str::<ResponseEvent>(body_str) {
|
||||
match serde_json::from_str::<Event>(body_str) {
|
||||
Ok(_) => Err(anyhow!(
|
||||
"Unexpected success response while expecting an error: {}",
|
||||
body_str,
|
||||
@@ -234,42 +169,183 @@ pub async fn stream_completion(
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use http::IsahcHttpClient;
|
||||
pub fn extract_text_from_events(
|
||||
response: impl Stream<Item = Result<Event>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
response.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
Event::ContentBlockStart { content_block, .. } => match content_block {
|
||||
Content::Text { text } => Some(Ok(text)),
|
||||
_ => None,
|
||||
},
|
||||
Event::ContentBlockDelta { delta, .. } => match delta {
|
||||
ContentDelta::TextDelta { text } => Some(Ok(text)),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn stream_completion_success() {
|
||||
// let http_client = IsahcHttpClient::new().unwrap();
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub content: Vec<Content>,
|
||||
}
|
||||
|
||||
// let request = Request {
|
||||
// model: Model::Claude3Opus,
|
||||
// messages: vec![RequestMessage {
|
||||
// role: Role::User,
|
||||
// content: "Ping".to_string(),
|
||||
// }],
|
||||
// stream: true,
|
||||
// system: "Respond to ping with pong".to_string(),
|
||||
// max_tokens: 4096,
|
||||
// };
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
// let stream = stream_completion(
|
||||
// &http_client,
|
||||
// "https://api.anthropic.com",
|
||||
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
|
||||
// request,
|
||||
// )
|
||||
// .await
|
||||
// .unwrap();
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum Content {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image")]
|
||||
Image { source: ImageSource },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
// stream
|
||||
// .for_each(|event| async {
|
||||
// match event {
|
||||
// Ok(event) => println!("{:?}", event),
|
||||
// Err(e) => eprintln!("Error: {:?}", e),
|
||||
// }
|
||||
// })
|
||||
// .await;
|
||||
// }
|
||||
// }
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ImageSource {
|
||||
#[serde(rename = "type")]
|
||||
pub source_type: String,
|
||||
pub media_type: String,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
Tool { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<Tool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<Metadata>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop_sequences: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct StreamingRequest {
|
||||
#[serde(flatten)]
|
||||
pub base: Request,
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Metadata {
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub input_tokens: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Response {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: Role,
|
||||
pub content: Vec<Content>,
|
||||
pub model: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub stop_reason: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum Event {
|
||||
#[serde(rename = "message_start")]
|
||||
MessageStart { message: Response },
|
||||
#[serde(rename = "content_block_start")]
|
||||
ContentBlockStart {
|
||||
index: usize,
|
||||
content_block: Content,
|
||||
},
|
||||
#[serde(rename = "content_block_delta")]
|
||||
ContentBlockDelta { index: usize, delta: ContentDelta },
|
||||
#[serde(rename = "content_block_stop")]
|
||||
ContentBlockStop { index: usize },
|
||||
#[serde(rename = "message_delta")]
|
||||
MessageDelta { delta: MessageDelta, usage: Usage },
|
||||
#[serde(rename = "message_stop")]
|
||||
MessageStop,
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
#[serde(rename = "error")]
|
||||
Error { error: ApiError },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentDelta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct MessageDelta {
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ApiError {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
@@ -75,7 +75,6 @@ util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
picker.workspace = true
|
||||
roxmltree = "0.20.0"
|
||||
|
||||
[dev-dependencies]
|
||||
completion = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -38,7 +38,7 @@ Considering these aspects will ensure our conversation view design is optimized
|
||||
|
||||
@nate> 2 feels like it isn't important at the moment, we can explore that later. Let's start with 4, which I think will lead us to discussion 3 and 5.
|
||||
|
||||
#zed share your thoughts on the points we need to consider to design a layout and visualization for a conversation view between you (#zed) and multuple peoople, or between multiple people and multiple bots (you and other bots).
|
||||
#zed share your thoughts on the points we need to consider to design a layout and visualization for a conversation view between you (#zed) and multiple people, or between multiple people and multiple bots (you and other bots).
|
||||
|
||||
@nathan> Agreed. I'm interested in threading I think more than anything. Or 4 yeah. I think we need to scope the threading conversation. Also, asking #zed to propose the solution... not sure it will be that effective but it's worth a try...
|
||||
|
||||
|
||||
@@ -1207,12 +1207,16 @@ impl ContextEditor {
|
||||
|
||||
fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
|
||||
if let Some(step) = self.active_edit_step.as_ref() {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
for assist_id in &step.assist_ids {
|
||||
assistant.start_assist(*assist_id, cx);
|
||||
}
|
||||
!step.assist_ids.is_empty()
|
||||
})
|
||||
let assist_ids = step.assist_ids.clone();
|
||||
cx.window_context().defer(|cx| {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
for assist_id in assist_ids {
|
||||
assistant.start_assist(assist_id, cx);
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
!step.assist_ids.is_empty()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -1261,11 +1265,7 @@ impl ContextEditor {
|
||||
.collect::<String>()
|
||||
));
|
||||
match &step.operations {
|
||||
Some(EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
}) => {
|
||||
output.push_str(&format!("Raw Output:\n{raw_output}\n"));
|
||||
Some(EditStepOperations::Ready(operations)) => {
|
||||
output.push_str("Parsed Operations:\n");
|
||||
for op in operations {
|
||||
output.push_str(&format!(" {:?}\n", op));
|
||||
@@ -1769,13 +1769,12 @@ impl ContextEditor {
|
||||
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
||||
.unwrap()
|
||||
};
|
||||
let initial_text = suggestion.prepend_newline.then(|| "\n".into());
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assist_ids.push(assistant.suggest_assist(
|
||||
&editor,
|
||||
range,
|
||||
description,
|
||||
initial_text,
|
||||
suggestion.initial_insertion,
|
||||
Some(workspace.clone()),
|
||||
assistant_panel.upgrade().as_ref(),
|
||||
cx,
|
||||
@@ -1837,9 +1836,11 @@ impl ContextEditor {
|
||||
.anchor_in_excerpt(excerpt_id, suggestion.range.end)
|
||||
.unwrap()
|
||||
};
|
||||
let initial_text =
|
||||
suggestion.prepend_newline.then(|| "\n".to_string());
|
||||
inline_assist_suggestions.push((range, description, initial_text));
|
||||
inline_assist_suggestions.push((
|
||||
range,
|
||||
description,
|
||||
suggestion.initial_insertion,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1850,12 +1851,12 @@ impl ContextEditor {
|
||||
.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?;
|
||||
cx.update(|cx| {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
for (range, description, initial_text) in inline_assist_suggestions {
|
||||
for (range, description, initial_insertion) in inline_assist_suggestions {
|
||||
assist_ids.push(assistant.suggest_assist(
|
||||
&editor,
|
||||
range,
|
||||
description,
|
||||
initial_text,
|
||||
initial_insertion,
|
||||
Some(workspace.clone()),
|
||||
assistant_panel.upgrade().as_ref(),
|
||||
cx,
|
||||
@@ -2163,7 +2164,7 @@ impl ContextEditor {
|
||||
let button_text = match self.edit_step_for_cursor(cx) {
|
||||
Some(edit_step) => match &edit_step.operations {
|
||||
Some(EditStepOperations::Pending(_)) => "Computing Changes...",
|
||||
Some(EditStepOperations::Parsed { .. }) => "Apply Changes",
|
||||
Some(EditStepOperations::Ready(_)) => "Apply Changes",
|
||||
None => "Send",
|
||||
},
|
||||
None => "Send",
|
||||
|
||||
@@ -249,9 +249,7 @@ impl AssistantSettingsContent {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
|
||||
"zed.dev" => {
|
||||
settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
|
||||
default_model: CloudModel::from_id(&model).ok(),
|
||||
});
|
||||
log::warn!("attempted to set zed.dev model on outdated settings");
|
||||
}
|
||||
"anthropic" => {
|
||||
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
|
||||
MessageId, MessageStatus,
|
||||
prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
|
||||
LanguageModelCompletionProvider, MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_slash_command::{
|
||||
@@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
|
||||
use language::{
|
||||
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
|
||||
};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::{LanguageModelRequest, Role};
|
||||
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
cmp,
|
||||
@@ -352,7 +352,7 @@ pub struct EditSuggestion {
|
||||
pub range: Range<language::Anchor>,
|
||||
/// If None, assume this is a suggestion to delete the range rather than transform it.
|
||||
pub description: Option<String>,
|
||||
pub prepend_newline: bool,
|
||||
pub initial_insertion: Option<InitialInsertion>,
|
||||
}
|
||||
|
||||
impl EditStep {
|
||||
@@ -361,7 +361,7 @@ impl EditStep {
|
||||
project: &Model<Project>,
|
||||
cx: &AppContext,
|
||||
) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
|
||||
let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
|
||||
let Some(EditStepOperations::Ready(operations)) = &self.operations else {
|
||||
return Task::ready(HashMap::default());
|
||||
};
|
||||
|
||||
@@ -471,32 +471,28 @@ impl EditStep {
|
||||
}
|
||||
|
||||
pub enum EditStepOperations {
|
||||
Pending(Task<Result<()>>),
|
||||
Parsed {
|
||||
operations: Vec<EditOperation>,
|
||||
raw_output: String,
|
||||
},
|
||||
Pending(Task<Option<()>>),
|
||||
Ready(Vec<EditOperation>),
|
||||
}
|
||||
|
||||
impl Debug for EditStepOperations {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
|
||||
EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
} => f
|
||||
EditStepOperations::Ready(operations) => f
|
||||
.debug_struct("EditStepOperations::Parsed")
|
||||
.field("operations", operations)
|
||||
.field("raw_output", raw_output)
|
||||
.finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
/// A description of an operation to apply to one location in the codebase.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||
pub struct EditOperation {
|
||||
/// The path to the file containing the relevant operation
|
||||
pub path: String,
|
||||
#[serde(flatten)]
|
||||
pub kind: EditOperationKind,
|
||||
}
|
||||
|
||||
@@ -523,7 +519,7 @@ impl EditOperation {
|
||||
parse_status.changed().await?;
|
||||
}
|
||||
|
||||
let prepend_newline = kind.prepend_newline();
|
||||
let initial_insertion = kind.initial_insertion();
|
||||
let suggestion_range = if let Some(symbol) = kind.symbol() {
|
||||
let outline = buffer
|
||||
.update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
|
||||
@@ -532,7 +528,21 @@ impl EditOperation {
|
||||
.path_candidates
|
||||
.iter()
|
||||
.find(|item| item.string == symbol)
|
||||
.context("symbol not found")?;
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"symbol {:?} not found in path {:?}.\ncandidates: {:?}.\nparse status: {:?}. text:\n{}",
|
||||
symbol,
|
||||
path,
|
||||
outline
|
||||
.path_candidates
|
||||
.iter()
|
||||
.map(|candidate| &candidate.string)
|
||||
.collect::<Vec<_>>(),
|
||||
*parse_status.borrow(),
|
||||
buffer.read_with(&cx, |buffer, _| buffer.text()).unwrap_or_else(|_| "error".to_string())
|
||||
)
|
||||
})?;
|
||||
|
||||
buffer.update(&mut cx, |buffer, _| {
|
||||
let outline_item = &outline.items[candidate.id];
|
||||
let symbol_range = outline_item.range.to_point(buffer);
|
||||
@@ -587,39 +597,61 @@ impl EditOperation {
|
||||
EditSuggestion {
|
||||
range: suggestion_range,
|
||||
description: kind.description().map(ToString::to_string),
|
||||
prepend_newline,
|
||||
initial_insertion,
|
||||
},
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum EditOperationKind {
|
||||
/// Rewrite the specified symbol in its entirely based on the given description.
|
||||
Update {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Create a new file with the given path based on the given description.
|
||||
Create {
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol based on the given description before the specified symbol.
|
||||
InsertSiblingBefore {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol based on the given description after the specified symbol.
|
||||
InsertSiblingAfter {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol as a child of the specified symbol at the start.
|
||||
PrependChild {
|
||||
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||
/// If not provided, the edit should be applied at the top of the file.
|
||||
symbol: Option<String>,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Insert a new symbol as a child of the specified symbol at the end.
|
||||
AppendChild {
|
||||
/// An optional full path to the symbol to be rewritten from the provided list.
|
||||
/// If not provided, the edit should be applied at the top of the file.
|
||||
symbol: Option<String>,
|
||||
/// A brief one-line description of the change that should be applied.
|
||||
description: String,
|
||||
},
|
||||
/// Delete the specified symbol.
|
||||
Delete {
|
||||
/// A full path to the symbol to be rewritten from the provided list.
|
||||
symbol: String,
|
||||
},
|
||||
}
|
||||
@@ -649,13 +681,13 @@ impl EditOperationKind {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prepend_newline(&self) -> bool {
|
||||
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
|
||||
match self {
|
||||
Self::PrependChild { .. }
|
||||
| Self::AppendChild { .. }
|
||||
| Self::InsertSiblingAfter { .. }
|
||||
| Self::InsertSiblingBefore { .. } => true,
|
||||
_ => false,
|
||||
EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
|
||||
EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1128,12 +1160,10 @@ impl Context {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
@@ -1289,7 +1319,24 @@ impl Context {
|
||||
&self,
|
||||
edit_step: &EditStep,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
) -> Task<Option<()>> {
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct EditTool {
|
||||
/// A sequence of operations to apply to the codebase.
|
||||
/// When multiple operations are required for a step, be sure to include multiple operations in this list.
|
||||
operations: Vec<EditOperation>,
|
||||
}
|
||||
|
||||
impl LanguageModelTool for EditTool {
|
||||
fn name() -> String {
|
||||
"edit".into()
|
||||
}
|
||||
|
||||
fn description() -> String {
|
||||
"suggest edits to one or more locations in the codebase".into()
|
||||
}
|
||||
}
|
||||
|
||||
let mut request = self.to_completion_request(cx);
|
||||
let edit_step_range = edit_step.source_range.clone();
|
||||
let step_text = self
|
||||
@@ -1298,160 +1345,41 @@ impl Context {
|
||||
.text_for_range(edit_step_range.clone())
|
||||
.collect::<String>();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||
cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;
|
||||
|
||||
let mut prompt = prompt_store.operations_prompt();
|
||||
prompt.push_str(&step_text);
|
||||
let mut prompt = prompt_store.operations_prompt();
|
||||
prompt.push_str(&step_text);
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
let raw_output = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).complete(request, cx)
|
||||
let tool_use = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.use_tool::<EditTool>(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
.edit_steps
|
||||
.binary_search_by(|step| {
|
||||
step.source_range
|
||||
.cmp(&edit_step_range, this.buffer.read(cx))
|
||||
})
|
||||
.map_err(|_| anyhow!("edit step not found"))?;
|
||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||
edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
|
||||
cx.emit(ContextEvent::EditStepsChanged);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let operations = Self::parse_edit_operations(&raw_output);
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let step_index = this
|
||||
.edit_steps
|
||||
.binary_search_by(|step| {
|
||||
step.source_range
|
||||
.cmp(&edit_step_range, this.buffer.read(cx))
|
||||
})
|
||||
.map_err(|_| anyhow!("edit step not found"))?;
|
||||
if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
|
||||
edit_step.operations = Some(EditStepOperations::Parsed {
|
||||
operations,
|
||||
raw_output,
|
||||
});
|
||||
cx.emit(ContextEvent::EditStepsChanged);
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_edit_operations(xml: &str) -> Vec<EditOperation> {
|
||||
let Some(start_ix) = xml.find("<operations>") else {
|
||||
return Vec::new();
|
||||
};
|
||||
let Some(end_ix) = xml[start_ix..].find("</operations>") else {
|
||||
return Vec::new();
|
||||
};
|
||||
let end_ix = end_ix + start_ix + "</operations>".len();
|
||||
|
||||
let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err();
|
||||
doc.map_or(Vec::new(), |doc| {
|
||||
doc.root_element()
|
||||
.children()
|
||||
.map(|node| {
|
||||
let tag_name = node.tag_name().name();
|
||||
let path = node
|
||||
.attribute("path")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'path'")
|
||||
})?
|
||||
.to_string();
|
||||
let kind = match tag_name {
|
||||
"update" => EditOperationKind::Update {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"create" => EditOperationKind::Create {
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"insert_sibling_after" => EditOperationKind::InsertSiblingAfter {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"insert_sibling_before" => EditOperationKind::InsertSiblingBefore {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"prepend_child" => EditOperationKind::PrependChild {
|
||||
symbol: node.attribute("symbol").map(String::from),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"append_child" => EditOperationKind::AppendChild {
|
||||
symbol: node.attribute("symbol").map(String::from),
|
||||
description: node
|
||||
.attribute("description")
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"invalid node {node:?}, missing attribute 'description'"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
"delete" => EditOperationKind::Delete {
|
||||
symbol: node
|
||||
.attribute("symbol")
|
||||
.with_context(|| {
|
||||
format!("invalid node {node:?}, missing attribute 'symbol'")
|
||||
})?
|
||||
.to_string(),
|
||||
},
|
||||
_ => return Err(anyhow!("invalid node {node:?}")),
|
||||
};
|
||||
anyhow::Ok(EditOperation { path, kind })
|
||||
})
|
||||
.filter_map(|op| op.log_err())
|
||||
.collect()
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3068,55 +2996,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_edit_operations() {
|
||||
let operations = indoc! {r#"
|
||||
Here are the operations to make all fields of the Canvas struct private:
|
||||
|
||||
<operations>
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field" />
|
||||
<update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field" />
|
||||
</operations>
|
||||
"#};
|
||||
|
||||
let parsed_operations = Context::parse_edit_operations(operations);
|
||||
assert_eq!(
|
||||
parsed_operations,
|
||||
vec![
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub pixels".to_string(),
|
||||
description: "Remove pub keyword from pixels field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub size".to_string(),
|
||||
description: "Remove pub keyword from size field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub stride".to_string(),
|
||||
description: "Remove pub keyword from stride field".to_string(),
|
||||
},
|
||||
},
|
||||
EditOperation {
|
||||
path: "font-kit/src/canvas.rs".to_string(),
|
||||
kind: EditOperationKind::Update {
|
||||
symbol: "pub struct Canvas pub format".to_string(),
|
||||
description: "Remove pub keyword from format field".to_string(),
|
||||
},
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_serialization(cx: &mut TestAppContext) {
|
||||
let settings_store = cx.update(SettingsStore::test);
|
||||
|
||||
@@ -17,7 +17,7 @@ use editor::{
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
future::LocalBoxFuture,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
stream::{self, BoxStream},
|
||||
SinkExt, Stream, StreamExt,
|
||||
};
|
||||
@@ -36,7 +36,7 @@ use similar::TextDiff;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
cmp,
|
||||
future::Future,
|
||||
future::{self, Future},
|
||||
mem,
|
||||
ops::{Range, RangeInclusive},
|
||||
pin::Pin,
|
||||
@@ -46,7 +46,7 @@ use std::{
|
||||
};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{prelude::*, IconButtonShape, Tooltip};
|
||||
use util::RangeExt;
|
||||
use util::{RangeExt, ResultExt};
|
||||
use workspace::{notifications::NotificationId, Toast, Workspace};
|
||||
|
||||
pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
|
||||
@@ -187,7 +187,13 @@ impl InlineAssistant {
|
||||
let [prompt_block_id, end_block_id] =
|
||||
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
|
||||
|
||||
assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
|
||||
assists.push((
|
||||
assist_id,
|
||||
range,
|
||||
prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
));
|
||||
}
|
||||
|
||||
let editor_assists = self
|
||||
@@ -195,7 +201,7 @@ impl InlineAssistant {
|
||||
.entry(editor.downgrade())
|
||||
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
|
||||
let mut assist_group = InlineAssistGroup::new();
|
||||
for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
|
||||
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
|
||||
self.assists.insert(
|
||||
assist_id,
|
||||
InlineAssist::new(
|
||||
@@ -206,6 +212,7 @@ impl InlineAssistant {
|
||||
&prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
range,
|
||||
prompt_editor.read(cx).codegen.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
@@ -227,7 +234,7 @@ impl InlineAssistant {
|
||||
editor: &View<Editor>,
|
||||
mut range: Range<Anchor>,
|
||||
initial_prompt: String,
|
||||
initial_insertion: Option<String>,
|
||||
initial_insertion: Option<InitialInsertion>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
assistant_panel: Option<&View<AssistantPanel>>,
|
||||
cx: &mut WindowContext,
|
||||
@@ -239,22 +246,30 @@ impl InlineAssistant {
|
||||
let assist_id = self.next_assist_id.post_inc();
|
||||
|
||||
let buffer = editor.read(cx).buffer().clone();
|
||||
let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.start_transaction(cx);
|
||||
buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
|
||||
buffer.end_transaction(cx)
|
||||
})
|
||||
});
|
||||
{
|
||||
let snapshot = buffer.read(cx).read(cx);
|
||||
|
||||
range.start = range.start.bias_left(&buffer.read(cx).read(cx));
|
||||
range.end = range.end.bias_right(&buffer.read(cx).read(cx));
|
||||
let mut point_range = range.to_point(&snapshot);
|
||||
if point_range.is_empty() {
|
||||
point_range.start.column = 0;
|
||||
point_range.end.column = 0;
|
||||
} else {
|
||||
point_range.start.column = 0;
|
||||
if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
|
||||
point_range.end.row -= 1;
|
||||
}
|
||||
point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
|
||||
}
|
||||
|
||||
range.start = snapshot.anchor_before(point_range.start);
|
||||
range.end = snapshot.anchor_after(point_range.end);
|
||||
}
|
||||
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
editor.read(cx).buffer().clone(),
|
||||
range.clone(),
|
||||
prepend_transaction_id,
|
||||
initial_insertion,
|
||||
self.telemetry.clone(),
|
||||
cx,
|
||||
)
|
||||
@@ -295,6 +310,7 @@ impl InlineAssistant {
|
||||
&prompt_editor,
|
||||
prompt_block_id,
|
||||
end_block_id,
|
||||
range,
|
||||
prompt_editor.read(cx).codegen.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
@@ -445,7 +461,7 @@ impl InlineAssistant {
|
||||
let buffer = editor.buffer().read(cx).snapshot(cx);
|
||||
for assist_id in &editor_assists.assist_ids {
|
||||
let assist = &self.assists[assist_id];
|
||||
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
|
||||
let assist_range = assist.range.to_offset(&buffer);
|
||||
if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
|
||||
{
|
||||
if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
|
||||
@@ -473,7 +489,7 @@ impl InlineAssistant {
|
||||
let buffer = editor.buffer().read(cx).snapshot(cx);
|
||||
for assist_id in &editor_assists.assist_ids {
|
||||
let assist = &self.assists[assist_id];
|
||||
let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
|
||||
let assist_range = assist.range.to_offset(&buffer);
|
||||
if assist.decorations.is_some()
|
||||
&& assist_range.contains(&selection.start)
|
||||
&& assist_range.contains(&selection.end)
|
||||
@@ -551,7 +567,7 @@ impl InlineAssistant {
|
||||
assist.codegen.read(cx).status,
|
||||
CodegenStatus::Error(_) | CodegenStatus::Done
|
||||
) {
|
||||
let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
|
||||
let assist_range = assist.range.to_offset(&snapshot);
|
||||
if edited_ranges
|
||||
.iter()
|
||||
.any(|range| range.overlaps(&assist_range))
|
||||
@@ -721,7 +737,7 @@ impl InlineAssistant {
|
||||
});
|
||||
}
|
||||
|
||||
let position = assist.codegen.read(cx).range.start;
|
||||
let position = assist.range.start;
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.change_selections(None, cx, |selections| {
|
||||
selections.select_anchor_ranges([position..position])
|
||||
@@ -740,8 +756,7 @@ impl InlineAssistant {
|
||||
.0 as f32;
|
||||
} else {
|
||||
let snapshot = editor.snapshot(cx);
|
||||
let codegen = assist.codegen.read(cx);
|
||||
let start_row = codegen
|
||||
let start_row = assist
|
||||
.range
|
||||
.start
|
||||
.to_display_point(&snapshot.display_snapshot)
|
||||
@@ -829,11 +844,7 @@ impl InlineAssistant {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(user_prompt) = assist
|
||||
.decorations
|
||||
.as_ref()
|
||||
.map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
|
||||
else {
|
||||
let Some(user_prompt) = assist.user_prompt(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -843,139 +854,19 @@ impl InlineAssistant {
|
||||
self.prompt_history.pop_front();
|
||||
}
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|m| m.telemetry_id())
|
||||
.unwrap_or_default();
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.request_for_inline_assist(assist_id, cx);
|
||||
let mut cx = cx.to_async();
|
||||
async move {
|
||||
let request = request.await?;
|
||||
let chunks = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.stream_completion(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
Ok(chunks.boxed())
|
||||
}
|
||||
.boxed_local()
|
||||
};
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(telemetry_id, chunks, cx);
|
||||
});
|
||||
}
|
||||
let assistant_panel_context = assist.assistant_panel_context(cx);
|
||||
|
||||
fn request_for_inline_assist(
|
||||
&self,
|
||||
assist_id: InlineAssistId,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<LanguageModelRequest>> {
|
||||
cx.spawn(|mut cx| async move {
|
||||
let (user_prompt, context_request, project_name, buffer, range) =
|
||||
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
|
||||
let assist = this.assists.get(&assist_id).context("invalid assist")?;
|
||||
let decorations = assist.decorations.as_ref().context("invalid assist")?;
|
||||
let editor = assist.editor.upgrade().context("invalid assist")?;
|
||||
let user_prompt = decorations.prompt_editor.read(cx).prompt(cx);
|
||||
let context_request = if assist.include_context {
|
||||
assist.workspace.as_ref().and_then(|workspace| {
|
||||
let workspace = workspace.upgrade()?.read(cx);
|
||||
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
|
||||
Some(
|
||||
assistant_panel
|
||||
.read(cx)
|
||||
.active_context(cx)?
|
||||
.read(cx)
|
||||
.to_completion_request(cx),
|
||||
)
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let project_name = assist.workspace.as_ref().and_then(|workspace| {
|
||||
let workspace = workspace.upgrade()?;
|
||||
Some(
|
||||
workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.read(cx)
|
||||
.worktree_root_names(cx)
|
||||
.collect::<Vec<&str>>()
|
||||
.join("/"),
|
||||
)
|
||||
});
|
||||
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let range = assist.codegen.read(cx).range.clone();
|
||||
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
|
||||
})??;
|
||||
|
||||
let language = buffer.language_at(range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||
None
|
||||
} else {
|
||||
Some(language.name())
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Higher Temperature increases the randomness of model outputs.
|
||||
// If Markdown or No Language is Known, increase the randomness for more creative output
|
||||
// If Code, decrease temperature to get more deterministic outputs
|
||||
let temperature = if let Some(language) = language_name.clone() {
|
||||
if language.as_ref() == "Markdown" {
|
||||
1.0
|
||||
} else {
|
||||
0.5
|
||||
}
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let prompt = cx
|
||||
.background_executor()
|
||||
.spawn(async move {
|
||||
let language_name = language_name.as_deref();
|
||||
let start = buffer.point_to_buffer_offset(range.start);
|
||||
let end = buffer.point_to_buffer_offset(range.end);
|
||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||
let (start_buffer, start_buffer_offset) = start;
|
||||
let (end_buffer, end_buffer_offset) = end;
|
||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||
} else {
|
||||
return Err(anyhow!("invalid transformation range"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("invalid transformation range"));
|
||||
};
|
||||
generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if let Some(context_request) = context_request {
|
||||
messages = context_request.messages;
|
||||
}
|
||||
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
messages,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
assist
|
||||
.codegen
|
||||
.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
assist.range.clone(),
|
||||
user_prompt,
|
||||
assistant_panel_context,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
|
||||
@@ -1006,12 +897,11 @@ impl InlineAssistant {
|
||||
let codegen = assist.codegen.read(cx);
|
||||
foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
|
||||
|
||||
if codegen.edit_position != codegen.range.end {
|
||||
gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
|
||||
}
|
||||
gutter_pending_ranges
|
||||
.push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
|
||||
|
||||
if codegen.range.start != codegen.edit_position {
|
||||
gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
|
||||
if let Some(edit_position) = codegen.edit_position {
|
||||
gutter_transformed_ranges.push(assist.range.start..edit_position);
|
||||
}
|
||||
|
||||
if assist.decorations.is_some() {
|
||||
@@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum InitialInsertion {
|
||||
NewlineBefore,
|
||||
NewlineAfter,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct InlineAssistId(usize);
|
||||
|
||||
@@ -1420,27 +1316,34 @@ impl Render for PromptEditor {
|
||||
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
|
||||
.justify_center()
|
||||
.gap_2()
|
||||
.child(ModelSelector::new(
|
||||
self.fs.clone(),
|
||||
IconButton::new("context", IconName::Settings)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(move |cx| {
|
||||
Tooltip::with_meta(
|
||||
format!(
|
||||
"Using {}",
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
),
|
||||
None,
|
||||
"Change Model",
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
))
|
||||
.child(
|
||||
ModelSelector::new(
|
||||
self.fs.clone(),
|
||||
IconButton::new("context", IconName::SlidersAlt)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(move |cx| {
|
||||
Tooltip::with_meta(
|
||||
format!(
|
||||
"Using {}",
|
||||
LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model()
|
||||
.map(|model| model.name().0)
|
||||
.unwrap_or_else(|| "No model selected".into()),
|
||||
),
|
||||
None,
|
||||
"Change Model",
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
)
|
||||
.with_info_text(
|
||||
"Inline edits use context\n\
|
||||
from the currently selected\n\
|
||||
assistant panel tab.",
|
||||
),
|
||||
)
|
||||
.children(
|
||||
if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
@@ -1622,17 +1525,16 @@ impl PromptEditor {
|
||||
let assist_id = self.id;
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| async move {
|
||||
cx.background_executor().timer(Duration::from_secs(1)).await;
|
||||
let request = cx
|
||||
let token_count = cx
|
||||
.update_global(|inline_assistant: &mut InlineAssistant, cx| {
|
||||
inline_assistant.request_for_inline_assist(assist_id, cx)
|
||||
})?
|
||||
let assist = inline_assistant
|
||||
.assists
|
||||
.get(&assist_id)
|
||||
.context("assist not found")?;
|
||||
anyhow::Ok(assist.count_tokens(cx))
|
||||
})??
|
||||
.await?;
|
||||
|
||||
let token_count = cx
|
||||
.update(|cx| {
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
})?
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify();
|
||||
@@ -1825,6 +1727,7 @@ impl PromptEditor {
|
||||
},
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
line_height: relative(1.3),
|
||||
@@ -1844,6 +1747,7 @@ impl PromptEditor {
|
||||
|
||||
struct InlineAssist {
|
||||
group_id: InlineAssistGroupId,
|
||||
range: Range<Anchor>,
|
||||
editor: WeakView<Editor>,
|
||||
decorations: Option<InlineAssistDecorations>,
|
||||
codegen: Model<Codegen>,
|
||||
@@ -1862,6 +1766,7 @@ impl InlineAssist {
|
||||
prompt_editor: &View<PromptEditor>,
|
||||
prompt_block_id: CustomBlockId,
|
||||
end_block_id: CustomBlockId,
|
||||
range: Range<Anchor>,
|
||||
codegen: Model<Codegen>,
|
||||
workspace: Option<WeakView<Workspace>>,
|
||||
cx: &mut WindowContext,
|
||||
@@ -1877,6 +1782,7 @@ impl InlineAssist {
|
||||
removed_line_block_ids: HashSet::default(),
|
||||
end_block_id,
|
||||
}),
|
||||
range,
|
||||
codegen: codegen.clone(),
|
||||
workspace: workspace.clone(),
|
||||
_subscriptions: vec![
|
||||
@@ -1952,6 +1858,41 @@ impl InlineAssist {
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
fn user_prompt(&self, cx: &AppContext) -> Option<String> {
|
||||
let decorations = self.decorations.as_ref()?;
|
||||
Some(decorations.prompt_editor.read(cx).prompt(cx))
|
||||
}
|
||||
|
||||
fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
|
||||
if self.include_context {
|
||||
let workspace = self.workspace.as_ref()?;
|
||||
let workspace = workspace.upgrade()?.read(cx);
|
||||
let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
|
||||
Some(
|
||||
assistant_panel
|
||||
.read(cx)
|
||||
.active_context(cx)?
|
||||
.read(cx)
|
||||
.to_completion_request(cx),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
|
||||
let Some(user_prompt) = self.user_prompt(cx) else {
|
||||
return future::ready(Err(anyhow!("no user prompt"))).boxed();
|
||||
};
|
||||
let assistant_panel_context = self.assistant_panel_context(cx);
|
||||
self.codegen.read(cx).count_tokens(
|
||||
self.range.clone(),
|
||||
user_prompt,
|
||||
assistant_panel_context,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct InlineAssistDecorations {
|
||||
@@ -1971,16 +1912,15 @@ pub struct Codegen {
|
||||
buffer: Model<MultiBuffer>,
|
||||
old_buffer: Model<Buffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
range: Range<Anchor>,
|
||||
edit_position: Anchor,
|
||||
edit_position: Option<Anchor>,
|
||||
last_equal_ranges: Vec<Range<Anchor>>,
|
||||
prepend_transaction_id: Option<TransactionId>,
|
||||
generation_transaction_id: Option<TransactionId>,
|
||||
transaction_id: Option<TransactionId>,
|
||||
status: CodegenStatus,
|
||||
generation: Task<()>,
|
||||
diff: Diff,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
initial_insertion: Option<InitialInsertion>,
|
||||
}
|
||||
|
||||
enum CodegenStatus {
|
||||
@@ -2004,7 +1944,7 @@ impl Codegen {
|
||||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
range: Range<Anchor>,
|
||||
prepend_transaction_id: Option<TransactionId>,
|
||||
initial_insertion: Option<InitialInsertion>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
@@ -2033,17 +1973,16 @@ impl Codegen {
|
||||
Self {
|
||||
buffer: buffer.clone(),
|
||||
old_buffer,
|
||||
edit_position: range.start,
|
||||
range,
|
||||
edit_position: None,
|
||||
snapshot,
|
||||
last_equal_ranges: Default::default(),
|
||||
prepend_transaction_id,
|
||||
generation_transaction_id: None,
|
||||
transaction_id: None,
|
||||
status: CodegenStatus::Idle,
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
initial_insertion,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2054,13 +1993,8 @@ impl Codegen {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
|
||||
if self.generation_transaction_id == Some(*transaction_id) {
|
||||
self.generation_transaction_id = None;
|
||||
self.generation = Task::ready(());
|
||||
cx.emit(CodegenEvent::Undone);
|
||||
} else if self.prepend_transaction_id == Some(*transaction_id) {
|
||||
self.prepend_transaction_id = None;
|
||||
self.generation_transaction_id = None;
|
||||
if self.transaction_id == Some(*transaction_id) {
|
||||
self.transaction_id = None;
|
||||
self.generation = Task::ready(());
|
||||
cx.emit(CodegenEvent::Undone);
|
||||
}
|
||||
@@ -2071,19 +2005,152 @@ impl Codegen {
|
||||
&self.last_equal_ranges
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
edit_range: Range<Anchor>,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
|
||||
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
&mut self,
|
||||
telemetry_id: String,
|
||||
mut edit_range: Range<Anchor>,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<()> {
|
||||
self.undo(cx);
|
||||
|
||||
// Handle initial insertion
|
||||
self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.start_transaction(cx);
|
||||
let offset = edit_range.start.to_offset(&self.snapshot);
|
||||
let edit_position;
|
||||
match initial_insertion {
|
||||
InitialInsertion::NewlineBefore => {
|
||||
buffer.edit([(offset..offset, "\n\n")], None, cx);
|
||||
self.snapshot = buffer.snapshot(cx);
|
||||
edit_position = self.snapshot.anchor_after(offset + 1);
|
||||
}
|
||||
InitialInsertion::NewlineAfter => {
|
||||
buffer.edit([(offset..offset, "\n")], None, cx);
|
||||
self.snapshot = buffer.snapshot(cx);
|
||||
edit_position = self.snapshot.anchor_after(offset);
|
||||
}
|
||||
}
|
||||
self.edit_position = Some(edit_position);
|
||||
edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
|
||||
buffer.end_transaction(cx)
|
||||
})
|
||||
} else {
|
||||
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
|
||||
None
|
||||
};
|
||||
|
||||
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
|
||||
.active_model_telemetry_id()
|
||||
.context("no active model")?;
|
||||
|
||||
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
|
||||
.trim()
|
||||
.to_lowercase()
|
||||
== "delete"
|
||||
{
|
||||
async { Ok(stream::empty().boxed()) }.boxed_local()
|
||||
} else {
|
||||
let request =
|
||||
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
|
||||
let chunks =
|
||||
LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
|
||||
async move { Ok(chunks.await?.boxed()) }.boxed_local()
|
||||
};
|
||||
self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
assistant_panel_context: Option<LanguageModelRequest>,
|
||||
edit_range: Range<Anchor>,
|
||||
cx: &AppContext,
|
||||
) -> LanguageModelRequest {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(edit_range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
|
||||
None
|
||||
} else {
|
||||
Some(language.name())
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Higher Temperature increases the randomness of model outputs.
|
||||
// If Markdown or No Language is Known, increase the randomness for more creative output
|
||||
// If Code, decrease temperature to get more deterministic outputs
|
||||
let temperature = if let Some(language) = language_name.clone() {
|
||||
if language.as_ref() == "Markdown" {
|
||||
1.0
|
||||
} else {
|
||||
0.5
|
||||
}
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let language_name = language_name.as_deref();
|
||||
let start = buffer.point_to_buffer_offset(edit_range.start);
|
||||
let end = buffer.point_to_buffer_offset(edit_range.end);
|
||||
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
|
||||
let (start_buffer, start_buffer_offset) = start;
|
||||
let (end_buffer, end_buffer_offset) = end;
|
||||
if start_buffer.remote_id() == end_buffer.remote_id() {
|
||||
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
|
||||
} else {
|
||||
panic!("invalid transformation range");
|
||||
}
|
||||
} else {
|
||||
panic!("invalid transformation range");
|
||||
};
|
||||
let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if let Some(context_request) = assistant_panel_context {
|
||||
messages = context_request.messages;
|
||||
}
|
||||
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: prompt,
|
||||
});
|
||||
|
||||
LanguageModelRequest {
|
||||
messages,
|
||||
stop: vec!["|END|>".to_string()],
|
||||
temperature,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
&mut self,
|
||||
model_telemetry_id: String,
|
||||
edit_range: Range<Anchor>,
|
||||
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let range = self.range.clone();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
.text_for_range(range.start..range.end)
|
||||
.text_for_range(edit_range.start..edit_range.end)
|
||||
.collect::<Rope>();
|
||||
|
||||
let selection_start = range.start.to_point(&snapshot);
|
||||
let selection_start = edit_range.start.to_point(&snapshot);
|
||||
|
||||
// Start with the indentation of the first line in the selection
|
||||
let mut suggested_line_indent = snapshot
|
||||
@@ -2094,7 +2161,7 @@ impl Codegen {
|
||||
|
||||
// If the first line in the selection does not have indentation, check the following lines
|
||||
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
|
||||
for row in selection_start.row..=range.end.to_point(&snapshot).row {
|
||||
for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
|
||||
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
|
||||
// Prefer tabs if a line in the selection uses tabs as indentation
|
||||
if line_indent.kind == IndentKind::Tab {
|
||||
@@ -2105,19 +2172,13 @@ impl Codegen {
|
||||
}
|
||||
|
||||
let telemetry = self.telemetry.clone();
|
||||
self.edit_position = range.start;
|
||||
self.diff = Diff::default();
|
||||
self.status = CodegenStatus::Pending;
|
||||
if let Some(transaction_id) = self.generation_transaction_id.take() {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
let mut edit_start = edit_range.start.to_offset(&snapshot);
|
||||
self.generation = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let chunks = stream.await;
|
||||
let generate = async {
|
||||
let mut edit_start = range.start.to_offset(&snapshot);
|
||||
|
||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||
let diff: Task<anyhow::Result<()>> =
|
||||
cx.background_executor().spawn(async move {
|
||||
@@ -2207,7 +2268,7 @@ impl Codegen {
|
||||
telemetry.report_assistant_event(
|
||||
None,
|
||||
telemetry_events::AssistantKind::Inline,
|
||||
telemetry_id,
|
||||
model_telemetry_id,
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
@@ -2251,13 +2312,13 @@ impl Codegen {
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
this.edit_position = snapshot.anchor_after(edit_start);
|
||||
this.edit_position = Some(snapshot.anchor_after(edit_start));
|
||||
|
||||
buffer.end_transaction(cx)
|
||||
});
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
if let Some(first_transaction) = this.generation_transaction_id {
|
||||
if let Some(first_transaction) = this.transaction_id {
|
||||
// Group all assistant edits into the first transaction.
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
buffer.merge_transactions(
|
||||
@@ -2267,14 +2328,14 @@ impl Codegen {
|
||||
)
|
||||
});
|
||||
} else {
|
||||
this.generation_transaction_id = Some(transaction);
|
||||
this.transaction_id = Some(transaction);
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction(cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
this.update_diff(cx);
|
||||
this.update_diff(edit_range.clone(), cx);
|
||||
cx.notify();
|
||||
})?;
|
||||
}
|
||||
@@ -2310,27 +2371,22 @@ impl Codegen {
|
||||
}
|
||||
|
||||
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
|
||||
if let Some(transaction_id) = self.prepend_transaction_id.take() {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
|
||||
if let Some(transaction_id) = self.generation_transaction_id.take() {
|
||||
if let Some(transaction_id) = self.transaction_id.take() {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
}
|
||||
|
||||
fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
|
||||
fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
|
||||
if self.diff.task.is_some() {
|
||||
self.diff.should_update = true;
|
||||
} else {
|
||||
self.diff.should_update = false;
|
||||
|
||||
let old_snapshot = self.snapshot.clone();
|
||||
let old_range = self.range.to_point(&old_snapshot);
|
||||
let old_range = edit_range.to_point(&old_snapshot);
|
||||
let new_snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let new_range = self.range.to_point(&new_snapshot);
|
||||
let new_range = edit_range.to_point(&new_snapshot);
|
||||
|
||||
self.diff.task = Some(cx.spawn(|this, mut cx| async move {
|
||||
let (deleted_row_ranges, inserted_row_ranges) = cx
|
||||
@@ -2411,7 +2467,7 @@ impl Codegen {
|
||||
this.diff.inserted_row_ranges = inserted_row_ranges;
|
||||
this.diff.task = None;
|
||||
if this.diff.should_update {
|
||||
this.update_diff(cx);
|
||||
this.update_diff(edit_range, cx);
|
||||
}
|
||||
cx.notify();
|
||||
})
|
||||
@@ -2618,12 +2674,14 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range,
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
@@ -2679,12 +2737,14 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
@@ -2744,12 +2804,14 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
@@ -2808,12 +2870,14 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
|
||||
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.start(
|
||||
codegen.handle_stream(
|
||||
String::new(),
|
||||
range.clone(),
|
||||
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
|
||||
use fs::Fs;
|
||||
use gpui::SharedString;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use settings::update_settings_file;
|
||||
use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
|
||||
@@ -11,6 +12,7 @@ pub struct ModelSelector<T: PopoverTrigger> {
|
||||
handle: Option<PopoverMenuHandle<ContextMenu>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
trigger: T,
|
||||
info_text: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl<T: PopoverTrigger> ModelSelector<T> {
|
||||
@@ -19,6 +21,7 @@ impl<T: PopoverTrigger> ModelSelector<T> {
|
||||
handle: None,
|
||||
fs,
|
||||
trigger,
|
||||
info_text: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +29,11 @@ impl<T: PopoverTrigger> ModelSelector<T> {
|
||||
self.handle = Some(handle);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
|
||||
self.info_text = Some(text.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
|
||||
@@ -35,8 +43,20 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
|
||||
menu = menu.with_handle(handle);
|
||||
}
|
||||
|
||||
let info_text = self.info_text.clone();
|
||||
|
||||
menu.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
if let Some(info_text) = info_text.clone() {
|
||||
menu = menu
|
||||
.custom_row(move |_cx| {
|
||||
Label::new(info_text.clone())
|
||||
.color(Color::Muted)
|
||||
.into_any_element()
|
||||
})
|
||||
.separator();
|
||||
}
|
||||
|
||||
for (index, provider) in LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
|
||||
@@ -749,6 +749,7 @@ impl PromptLibrary {
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap();
|
||||
prompt_editor.token_count = Some(token_count);
|
||||
|
||||
@@ -6,8 +6,7 @@ pub fn generate_content_prompt(
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
_project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
let content_type = match language_name {
|
||||
@@ -15,14 +14,16 @@ pub fn generate_content_prompt(
|
||||
writeln!(
|
||||
prompt,
|
||||
"Here's a file of text that I'm going to ask you to make an edit to."
|
||||
)?;
|
||||
)
|
||||
.unwrap();
|
||||
"text"
|
||||
}
|
||||
Some(language_name) => {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Here's a file of {language_name} that I'm going to ask you to make an edit to."
|
||||
)?;
|
||||
)
|
||||
.unwrap();
|
||||
"code"
|
||||
}
|
||||
};
|
||||
@@ -70,7 +71,7 @@ pub fn generate_content_prompt(
|
||||
write!(prompt, "</document>\n\n").unwrap();
|
||||
|
||||
if is_truncated {
|
||||
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
|
||||
writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
|
||||
}
|
||||
|
||||
if range.is_empty() {
|
||||
@@ -107,7 +108,7 @@ pub fn generate_content_prompt(
|
||||
prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
prompt
|
||||
}
|
||||
|
||||
pub fn generate_terminal_assistant_prompt(
|
||||
|
||||
@@ -33,7 +33,7 @@ impl DiagnosticsSlashCommand {
|
||||
if query.is_empty() {
|
||||
let workspace = workspace.read(cx);
|
||||
let entries = workspace.recent_navigation_history(Some(10), cx);
|
||||
let path_prefix: Arc<str> = "".into();
|
||||
let path_prefix: Arc<str> = Arc::default();
|
||||
Task::ready(
|
||||
entries
|
||||
.into_iter()
|
||||
|
||||
@@ -219,7 +219,7 @@ impl SlashCommand for DocsSlashCommand {
|
||||
if index {
|
||||
// We don't need to hold onto this task, as the `IndexedDocsStore` will hold it
|
||||
// until it completes.
|
||||
let _ = store.clone().index(package.as_str().into());
|
||||
drop(store.clone().index(package.as_str().into()));
|
||||
}
|
||||
|
||||
let items = store.search(package).await;
|
||||
|
||||
@@ -29,7 +29,7 @@ impl FileSlashCommand {
|
||||
let workspace = workspace.read(cx);
|
||||
let project = workspace.project().read(cx);
|
||||
let entries = workspace.recent_navigation_history(Some(10), cx);
|
||||
let path_prefix: Arc<str> = "".into();
|
||||
let path_prefix: Arc<str> = Arc::default();
|
||||
Task::ready(
|
||||
entries
|
||||
.into_iter()
|
||||
|
||||
@@ -906,6 +906,7 @@ impl PromptEditor {
|
||||
},
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
line_height: relative(1.3),
|
||||
@@ -943,7 +944,7 @@ impl TerminalTransaction {
|
||||
}
|
||||
|
||||
pub fn push(&mut self, hunk: String, cx: &mut AppContext) {
|
||||
// Ensure that the assistant cannot accidently execute commands that are streamed into the terminal
|
||||
// Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
|
||||
let input = hunk.replace(CARRIAGE_RETURN, " ");
|
||||
self.terminal
|
||||
.update(cx, |terminal, _| terminal.input(input));
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
[package]
|
||||
name = "assistant_tooling"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant_tooling.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
log.workspace = true
|
||||
project.workspace = true
|
||||
repair_json.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
sum_tree.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
unindent.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-GPL
|
||||
@@ -1,85 +0,0 @@
|
||||
# Assistant Tooling
|
||||
|
||||
Bringing Language Model tool calling to GPUI.
|
||||
|
||||
This unlocks:
|
||||
|
||||
- **Structured Extraction** of model responses
|
||||
- **Validation** of model inputs
|
||||
- **Execution** of chosen tools
|
||||
|
||||
## Overview
|
||||
|
||||
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
|
||||
|
||||
> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
|
||||
>
|
||||
> **Assistant**: "Sure, I can help with that. Let me see what I can find."
|
||||
>
|
||||
> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
|
||||
>
|
||||
> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
|
||||
>
|
||||
> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
|
||||
|
||||
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`.
|
||||
|
||||
## Using the Tool Registry
|
||||
|
||||
```rust
|
||||
let mut tool_registry = ToolRegistry::new();
|
||||
tool_registry
|
||||
.register(WeatherTool { api_client },
|
||||
})
|
||||
.unwrap(); // You can only register one tool per name
|
||||
|
||||
let completion = cx.update(|cx| {
|
||||
CompletionProvider::get(cx).complete(
|
||||
model_name,
|
||||
messages,
|
||||
Vec::new(),
|
||||
1.0,
|
||||
// The definitions get passed directly to OpenAI when you want
|
||||
// the model to be able to call your tool
|
||||
tool_registry.definitions(),
|
||||
)
|
||||
});
|
||||
|
||||
let mut stream = completion?.await?;
|
||||
|
||||
let mut message = AssistantMessage::new();
|
||||
|
||||
while let Some(delta) = stream.next().await {
|
||||
// As messages stream in, you'll get both assistant content
|
||||
if let Some(content) = &delta.content {
|
||||
message
|
||||
.body
|
||||
.update(cx, |message, cx| message.append(&content, cx));
|
||||
}
|
||||
|
||||
// And tool calls!
|
||||
for tool_call_delta in delta.tool_calls {
|
||||
let index = tool_call_delta.index as usize;
|
||||
if index >= message.tool_calls.len() {
|
||||
message.tool_calls.resize_with(index + 1, Default::default);
|
||||
}
|
||||
let tool_call = &mut message.tool_calls[index];
|
||||
|
||||
// Build up an ID
|
||||
if let Some(id) = &tool_call_delta.id {
|
||||
tool_call.id.push_str(id);
|
||||
}
|
||||
|
||||
tool_registry.update_tool_call(
|
||||
tool_call,
|
||||
tool_call_delta.name.as_deref(),
|
||||
tool_call_delta.arguments.as_deref(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Once the stream of tokens is complete, you can exexute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task<Result<()>>`.
|
||||
|
||||
As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`.
|
||||
@@ -1,13 +0,0 @@
|
||||
mod attachment_registry;
|
||||
mod project_context;
|
||||
mod tool_registry;
|
||||
|
||||
pub use attachment_registry::{
|
||||
AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
|
||||
UserAttachment,
|
||||
};
|
||||
pub use project_context::ProjectContext;
|
||||
pub use tool_registry::{
|
||||
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition,
|
||||
ToolRegistry, ToolView,
|
||||
};
|
||||
@@ -1,234 +0,0 @@
|
||||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use futures::future::join_all;
|
||||
use gpui::{AnyView, Render, Task, View, WindowContext};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub struct AttachmentRegistry {
|
||||
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
|
||||
}
|
||||
|
||||
pub trait AttachmentOutput {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
}
|
||||
|
||||
pub trait LanguageModelAttachment {
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
type View: Render + AttachmentOutput;
|
||||
|
||||
fn name(&self) -> Arc<str>;
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
/// A collected attachment from running an attachment tool
|
||||
pub struct UserAttachment {
|
||||
pub view: AnyView,
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedUserAttachment {
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
}
|
||||
|
||||
/// Internal representation of an attachment tool to allow us to treat them dynamically
|
||||
struct RegisteredAttachment {
|
||||
name: Arc<str>,
|
||||
enabled: AtomicBool,
|
||||
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
|
||||
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
|
||||
}
|
||||
|
||||
impl AttachmentRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registered_attachments: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
|
||||
let attachment = Arc::new(attachment);
|
||||
|
||||
let call = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
let attachment = attachment.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let serialized_output =
|
||||
result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
|
||||
let view = cx.update(|cx| attachment.view(result, cx))?;
|
||||
|
||||
Ok(UserAttachment {
|
||||
name: attachment.name(),
|
||||
view: view.into(),
|
||||
generate_fn: generate::<A>,
|
||||
serialized_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let deserialize = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
|
||||
let serialized_output = saved_attachment.serialized_output.clone();
|
||||
let output = match &serialized_output {
|
||||
Ok(serialized_output) => {
|
||||
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
|
||||
}
|
||||
Err(error) => Err(anyhow!("{error}")),
|
||||
};
|
||||
let view = attachment.view(output, cx).into();
|
||||
|
||||
Ok(UserAttachment {
|
||||
name: saved_attachment.name.clone(),
|
||||
view,
|
||||
serialized_output,
|
||||
generate_fn: generate::<A>,
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
self.registered_attachments.insert(
|
||||
TypeId::of::<A>(),
|
||||
RegisteredAttachment {
|
||||
name: attachment.name(),
|
||||
call,
|
||||
deserialize,
|
||||
enabled: AtomicBool::new(true),
|
||||
},
|
||||
);
|
||||
return;
|
||||
|
||||
fn generate<T: LanguageModelAttachment>(
|
||||
view: AnyView,
|
||||
project: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
view.downcast::<T::View>()
|
||||
.unwrap()
|
||||
.update(cx, |view, cx| T::View::generate(view, project, cx))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
|
||||
&self,
|
||||
is_enabled: bool,
|
||||
) {
|
||||
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
|
||||
attachment.enabled.store(is_enabled, SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
|
||||
if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
|
||||
attachment.enabled.load(SeqCst)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn call<A: LanguageModelAttachment + 'static>(
|
||||
&self,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<UserAttachment>> {
|
||||
let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
|
||||
return Task::ready(Err(anyhow!("no attachment tool")));
|
||||
};
|
||||
|
||||
(attachment.call)(cx)
|
||||
}
|
||||
|
||||
pub fn call_all_attachment_tools(
|
||||
self: Arc<Self>,
|
||||
cx: &mut WindowContext<'_>,
|
||||
) -> Task<Result<Vec<UserAttachment>>> {
|
||||
let this = self.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let attachment_tasks = cx.update(|cx| {
|
||||
let mut tasks = Vec::new();
|
||||
for attachment in this
|
||||
.registered_attachments
|
||||
.values()
|
||||
.filter(|attachment| attachment.enabled.load(SeqCst))
|
||||
{
|
||||
tasks.push((attachment.call)(cx))
|
||||
}
|
||||
|
||||
tasks
|
||||
})?;
|
||||
|
||||
let attachments = join_all(attachment_tasks.into_iter()).await;
|
||||
|
||||
Ok(attachments
|
||||
.into_iter()
|
||||
.filter_map(|attachment| attachment.log_err())
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize_user_attachment(
|
||||
&self,
|
||||
user_attachment: &UserAttachment,
|
||||
) -> SavedUserAttachment {
|
||||
SavedUserAttachment {
|
||||
name: user_attachment.name.clone(),
|
||||
serialized_output: user_attachment.serialized_output.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_user_attachment(
|
||||
&self,
|
||||
saved_user_attachment: SavedUserAttachment,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<UserAttachment> {
|
||||
if let Some(registered_attachment) = self
|
||||
.registered_attachments
|
||||
.values()
|
||||
.find(|attachment| attachment.name == saved_user_attachment.name)
|
||||
{
|
||||
(registered_attachment.deserialize)(&saved_user_attachment, cx)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"no attachment tool for name {}",
|
||||
saved_user_attachment.name
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserAttachment {
|
||||
pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
|
||||
let result = (self.generate_fn)(self.view.clone(), output, cx);
|
||||
if result.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,296 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{AppContext, Model, Task, WeakModel};
|
||||
use project::{Fs, Project, ProjectPath, Worktree};
|
||||
use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
|
||||
use sum_tree::TreeMap;
|
||||
|
||||
pub struct ProjectContext {
|
||||
files: TreeMap<ProjectPath, PathState>,
|
||||
project: WeakModel<Project>,
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum PathState {
|
||||
PathOnly,
|
||||
EntireFile,
|
||||
Excerpts { ranges: Vec<Range<usize>> },
|
||||
}
|
||||
|
||||
impl ProjectContext {
|
||||
pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
|
||||
Self {
|
||||
files: TreeMap::default(),
|
||||
fs,
|
||||
project,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_path(&mut self, project_path: ProjectPath) {
|
||||
if self.files.get(&project_path).is_none() {
|
||||
self.files.insert(project_path, PathState::PathOnly);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
|
||||
let previous_state = self
|
||||
.files
|
||||
.get(&project_path)
|
||||
.unwrap_or(&PathState::PathOnly);
|
||||
|
||||
let mut ranges = match previous_state {
|
||||
PathState::EntireFile => return,
|
||||
PathState::PathOnly => Vec::new(),
|
||||
PathState::Excerpts { ranges } => ranges.to_vec(),
|
||||
};
|
||||
|
||||
for new_range in new_ranges {
|
||||
let ix = ranges.binary_search_by(|probe| {
|
||||
if probe.end < new_range.start {
|
||||
Ordering::Less
|
||||
} else if probe.start > new_range.end {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
Ordering::Equal
|
||||
}
|
||||
});
|
||||
|
||||
match ix {
|
||||
Ok(mut ix) => {
|
||||
let existing = &mut ranges[ix];
|
||||
existing.start = existing.start.min(new_range.start);
|
||||
existing.end = existing.end.max(new_range.end);
|
||||
while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
|
||||
ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
|
||||
ranges.remove(ix + 1);
|
||||
}
|
||||
while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
|
||||
ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
|
||||
ranges.remove(ix - 1);
|
||||
ix -= 1;
|
||||
}
|
||||
}
|
||||
Err(ix) => {
|
||||
ranges.insert(ix, new_range.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.files
|
||||
.insert(project_path, PathState::Excerpts { ranges });
|
||||
}
|
||||
|
||||
pub fn add_file(&mut self, project_path: ProjectPath) {
|
||||
self.files.insert(project_path, PathState::EntireFile);
|
||||
}
|
||||
|
||||
pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
|
||||
let project = self
|
||||
.project
|
||||
.upgrade()
|
||||
.ok_or_else(|| anyhow!("project dropped"));
|
||||
let files = self.files.clone();
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|cx| async move {
|
||||
let project = project?;
|
||||
let mut result = "project structure:\n".to_string();
|
||||
|
||||
let mut last_worktree: Option<Model<Worktree>> = None;
|
||||
for (project_path, path_state) in files.iter() {
|
||||
if let Some(worktree) = &last_worktree {
|
||||
if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
|
||||
last_worktree = None;
|
||||
}
|
||||
}
|
||||
|
||||
let worktree;
|
||||
if let Some(last_worktree) = &last_worktree {
|
||||
worktree = last_worktree.clone();
|
||||
} else if let Some(tree) = project.read_with(&cx, |project, cx| {
|
||||
project.worktree_for_id(project_path.worktree_id, cx)
|
||||
})? {
|
||||
worktree = tree;
|
||||
last_worktree = Some(worktree.clone());
|
||||
let worktree_name =
|
||||
worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
|
||||
writeln!(&mut result, "# {}", worktree_name).unwrap();
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
|
||||
let path = &project_path.path;
|
||||
writeln!(&mut result, "## {}", path.display()).unwrap();
|
||||
|
||||
match path_state {
|
||||
PathState::PathOnly => {}
|
||||
PathState::EntireFile => {
|
||||
let text = fs.load(&worktree_abs_path.join(&path)).await?;
|
||||
writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
|
||||
}
|
||||
PathState::Excerpts { ranges } => {
|
||||
let text = fs.load(&worktree_abs_path.join(&path)).await?;
|
||||
|
||||
writeln!(&mut result, "~~~").unwrap();
|
||||
|
||||
// Assumption: ranges are in order, not overlapping
|
||||
let mut prev_range_end = 0;
|
||||
for range in ranges {
|
||||
if range.start > prev_range_end {
|
||||
writeln!(&mut result, "...").unwrap();
|
||||
prev_range_end = range.end;
|
||||
}
|
||||
|
||||
let mut start = range.start;
|
||||
let mut end = range.end.min(text.len());
|
||||
while !text.is_char_boundary(start) {
|
||||
start += 1;
|
||||
}
|
||||
while !text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
result.push_str(&text[start..end]);
|
||||
if !result.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
if prev_range_end < text.len() {
|
||||
writeln!(&mut result, "...").unwrap();
|
||||
}
|
||||
|
||||
writeln!(&mut result, "~~~").unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
|
||||
use unindent::Unindent as _;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_system_message_generation(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let file_3_contents = r#"
|
||||
fn test1() {}
|
||||
fn test2() {}
|
||||
fn test3() {}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/code",
|
||||
json!({
|
||||
"root1": {
|
||||
"lib": {
|
||||
"file1.rs": "mod example;",
|
||||
"file2.rs": "",
|
||||
},
|
||||
"test": {
|
||||
"file3.rs": file_3_contents,
|
||||
}
|
||||
},
|
||||
"root2": {
|
||||
"src": {
|
||||
"main.rs": ""
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(
|
||||
fs.clone(),
|
||||
["/code/root1".as_ref(), "/code/root2".as_ref()],
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let worktree_ids = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id())
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let mut ax = ProjectContext::new(project.downgrade(), fs);
|
||||
|
||||
ax.add_file(ProjectPath {
|
||||
worktree_id: worktree_ids[0],
|
||||
path: Path::new("lib/file1.rs").into(),
|
||||
});
|
||||
|
||||
let message = cx
|
||||
.update(|cx| ax.generate_system_message(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
r#"
|
||||
project structure:
|
||||
# root1
|
||||
## lib/file1.rs
|
||||
~~~
|
||||
mod example;
|
||||
~~~
|
||||
"#
|
||||
.unindent(),
|
||||
message
|
||||
);
|
||||
|
||||
ax.add_excerpts(
|
||||
ProjectPath {
|
||||
worktree_id: worktree_ids[0],
|
||||
path: Path::new("test/file3.rs").into(),
|
||||
},
|
||||
&[
|
||||
file_3_contents.find("fn test2").unwrap()
|
||||
..file_3_contents.find("fn test3").unwrap(),
|
||||
],
|
||||
);
|
||||
|
||||
let message = cx
|
||||
.update(|cx| ax.generate_system_message(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
r#"
|
||||
project structure:
|
||||
# root1
|
||||
## lib/file1.rs
|
||||
~~~
|
||||
mod example;
|
||||
~~~
|
||||
## test/file3.rs
|
||||
~~~
|
||||
...
|
||||
fn test2() {}
|
||||
...
|
||||
~~~
|
||||
"#
|
||||
.unindent(),
|
||||
message
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,526 +0,0 @@
|
||||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
|
||||
use repair_json::repair;
|
||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
fmt::Display,
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||
};
|
||||
use ui::ViewContext;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
registered_tools: HashMap<String, RegisteredTool>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
state: ToolFunctionCallState,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
enum ToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
KnownTool(Box<dyn InternalToolView>),
|
||||
ExecutedTool(Box<dyn InternalToolView>),
|
||||
}
|
||||
|
||||
trait InternalToolView {
|
||||
fn view(&self) -> AnyView;
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
|
||||
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct SavedToolFunctionCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
state: SavedToolFunctionCallState,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
enum SavedToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
KnownTool,
|
||||
ExecutedTool(Box<RawValue>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: RootSchema,
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool {
|
||||
type View: ToolView;
|
||||
|
||||
/// Returns the name of the tool.
|
||||
///
|
||||
/// This name is exposed to the language model to allow the model to pick
|
||||
/// which tools to use. As this name is used to identify the tool within a
|
||||
/// tool registry, it should be unique.
|
||||
fn name(&self) -> String;
|
||||
|
||||
/// Returns the description of the tool.
|
||||
///
|
||||
/// This can be used to _prompt_ the model as to what the tool does.
|
||||
fn description(&self) -> String;
|
||||
|
||||
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
|
||||
fn definition(&self) -> ToolFunctionDefinition {
|
||||
let root_schema = schema_for!(<Self::View as ToolView>::Input);
|
||||
|
||||
ToolFunctionDefinition {
|
||||
name: self.name(),
|
||||
description: self.description(),
|
||||
parameters: root_schema,
|
||||
}
|
||||
}
|
||||
|
||||
/// A view of the output of running the tool, for displaying to the user.
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
pub trait ToolView: Render {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type SerializedState: DeserializeOwned + Serialize;
|
||||
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
|
||||
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
|
||||
|
||||
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
struct RegisteredTool {
|
||||
enabled: AtomicBool,
|
||||
type_id: TypeId,
|
||||
build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
|
||||
definition: ToolFunctionDefinition,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registered_tools: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
|
||||
for tool in self.registered_tools.values() {
|
||||
if tool.type_id == TypeId::of::<T>() {
|
||||
tool.enabled.store(is_enabled, SeqCst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
|
||||
for tool in self.registered_tools.values() {
|
||||
if tool.type_id == TypeId::of::<T>() {
|
||||
return tool.enabled.load(SeqCst);
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
|
||||
self.registered_tools
|
||||
.values()
|
||||
.filter(|tool| tool.enabled.load(SeqCst))
|
||||
.map(|tool| tool.definition.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn update_tool_call(
|
||||
&self,
|
||||
call: &mut ToolFunctionCall,
|
||||
name: Option<&str>,
|
||||
arguments: Option<&str>,
|
||||
cx: &mut WindowContext,
|
||||
) {
|
||||
if let Some(name) = name {
|
||||
call.name.push_str(name);
|
||||
}
|
||||
if let Some(arguments) = arguments {
|
||||
if call.arguments.is_empty() {
|
||||
if let Some(tool) = self.registered_tools.get(&call.name) {
|
||||
let view = (tool.build_view)(cx);
|
||||
call.state = ToolFunctionCallState::KnownTool(view);
|
||||
} else {
|
||||
call.state = ToolFunctionCallState::NoSuchTool;
|
||||
}
|
||||
}
|
||||
call.arguments.push_str(arguments);
|
||||
|
||||
if let ToolFunctionCallState::KnownTool(view) = &call.state {
|
||||
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
|
||||
view.try_set_input(&repaired_arguments, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute_tool_call(
|
||||
&self,
|
||||
tool_call: &mut ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
|
||||
let task = view.execute(cx);
|
||||
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
|
||||
Some(task)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Option<AnyElement> {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::NoSuchTool => {
|
||||
Some(ui::Label::new("No such tool").into_any_element())
|
||||
}
|
||||
ToolFunctionCallState::Initializing => None,
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
Some(view.view().into_any_element())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn content_for_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
project_context: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::Initializing => String::new(),
|
||||
ToolFunctionCallState::NoSuchTool => {
|
||||
format!("No such tool: {}", tool_call.name)
|
||||
}
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
view.generate(project_context, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize_tool_call(
|
||||
&self,
|
||||
call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<SavedToolFunctionCall> {
|
||||
Ok(SavedToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
state: match &call.state {
|
||||
ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
|
||||
ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
|
||||
ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
|
||||
ToolFunctionCallState::ExecutedTool(view) => {
|
||||
SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn deserialize_tool_call(
|
||||
&self,
|
||||
call: &SavedToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<ToolFunctionCall> {
|
||||
let Some(tool) = self.registered_tools.get(&call.name) else {
|
||||
return Err(anyhow!("no such tool {}", call.name));
|
||||
};
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
state: match &call.state {
|
||||
SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
|
||||
SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
|
||||
SavedToolFunctionCallState::KnownTool => {
|
||||
log::error!("Deserialized tool that had not executed");
|
||||
let view = (tool.build_view)(cx);
|
||||
view.try_set_input(&call.arguments, cx);
|
||||
ToolFunctionCallState::KnownTool(view)
|
||||
}
|
||||
SavedToolFunctionCallState::ExecutedTool(output) => {
|
||||
let view = (tool.build_view)(cx);
|
||||
view.try_set_input(&call.arguments, cx);
|
||||
view.deserialize_output(output, cx)?;
|
||||
ToolFunctionCallState::ExecutedTool(view)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||
let name = tool.name();
|
||||
let registered_tool = RegisteredTool {
|
||||
type_id: TypeId::of::<T>(),
|
||||
definition: tool.definition(),
|
||||
enabled: AtomicBool::new(true),
|
||||
build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
|
||||
};
|
||||
|
||||
let previous = self.registered_tools.insert(name.clone(), registered_tool);
|
||||
if previous.is_some() {
|
||||
return Err(anyhow!("already registered a tool with name {}", name));
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ToolView> InternalToolView for View<T> {
|
||||
fn view(&self) -> AnyView {
|
||||
self.clone().into()
|
||||
}
|
||||
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
|
||||
self.update(cx, |view, cx| view.generate(project, cx))
|
||||
}
|
||||
|
||||
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
|
||||
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
|
||||
self.update(cx, |view, cx| {
|
||||
view.set_input(input, cx);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
|
||||
self.update(cx, |view, cx| view.execute(cx))
|
||||
}
|
||||
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
|
||||
let output = self.update(cx, |view, cx| view.serialize(cx));
|
||||
Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
|
||||
}
|
||||
|
||||
fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
|
||||
let state = serde_json::from_str::<T::SerializedState>(output.get())?;
|
||||
self.update(cx, |view, cx| view.deserialize(state, cx))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ToolFunctionDefinition {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let schema = serde_json::to_string(&self.parameters).ok();
|
||||
let schema = schema.unwrap_or("None".to_string());
|
||||
write!(f, "Name: {}:\n", self.name)?;
|
||||
write!(f, "Description: {}\n", self.description)?;
|
||||
write!(f, "Parameters: {}", schema)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use gpui::{div, prelude::*, Render, TestAppContext};
|
||||
use gpui::{EmptyView, View};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema)]
|
||||
struct WeatherQuery {
|
||||
location: String,
|
||||
unit: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
|
||||
struct WeatherResult {
|
||||
location: String,
|
||||
temperature: f64,
|
||||
unit: String,
|
||||
}
|
||||
|
||||
struct WeatherView {
|
||||
input: Option<WeatherQuery>,
|
||||
result: Option<WeatherResult>,
|
||||
|
||||
// Fake API call
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
struct WeatherTool {
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
impl WeatherView {
|
||||
fn new(current_weather: WeatherResult) -> Self {
|
||||
Self {
|
||||
input: None,
|
||||
result: None,
|
||||
current_weather,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for WeatherView {
|
||||
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
|
||||
match self.result {
|
||||
Some(ref result) => div()
|
||||
.child(format!("temperature: {}", result.temperature))
|
||||
.into_any_element(),
|
||||
None => div().child("Calculating weather...").into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolView for WeatherView {
|
||||
type Input = WeatherQuery;
|
||||
|
||||
type SerializedState = WeatherResult;
|
||||
|
||||
fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
|
||||
serde_json::to_string(&self.result).unwrap()
|
||||
}
|
||||
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
|
||||
self.input = Some(input);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
let input = self.input.as_ref().unwrap();
|
||||
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
let weather = self.current_weather.clone();
|
||||
|
||||
self.result = Some(weather);
|
||||
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
|
||||
self.current_weather.clone()
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Result<()> {
|
||||
self.current_weather = output;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelTool for WeatherTool {
|
||||
type View = WeatherView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"get_current_weather".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Fetches the current weather for a given location.".to_string()
|
||||
}
|
||||
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_openai_weather_example(cx: &mut TestAppContext) {
|
||||
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry
|
||||
.register(WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let definitions = registry.definitions();
|
||||
assert_eq!(
|
||||
definitions,
|
||||
[ToolFunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: "Fetches the current weather for a given location.".to_string(),
|
||||
parameters: serde_json::from_value(json!({
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "WeatherQuery",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
}))
|
||||
.unwrap(),
|
||||
}]
|
||||
);
|
||||
|
||||
let mut call = ToolFunctionCall {
|
||||
id: "the-id".to_string(),
|
||||
name: "get_cur".to_string(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let task = cx.update(|cx| {
|
||||
registry.update_tool_call(
|
||||
&mut call,
|
||||
Some("rent_weather"),
|
||||
Some(r#"{"location": "San Francisco","#),
|
||||
cx,
|
||||
);
|
||||
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
|
||||
registry.execute_tool_call(&mut call, cx).unwrap()
|
||||
});
|
||||
task.await.unwrap();
|
||||
|
||||
match &call.state {
|
||||
ToolFunctionCallState::ExecutedTool(_view) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -493,7 +493,7 @@ impl Room {
|
||||
// we leave the room and return an error.
|
||||
if let Some(this) = this.upgrade() {
|
||||
log::info!("reconnection failed, leaving room");
|
||||
let _ = this.update(&mut cx, |this, cx| this.leave(cx))?;
|
||||
let _ = this.update(&mut cx, |this, cx| this.leave(cx))?.await?;
|
||||
}
|
||||
Err(anyhow!(
|
||||
"can't reconnect to room: client failed to re-establish connection"
|
||||
@@ -942,7 +942,7 @@ impl Room {
|
||||
this.pending_room_update.take();
|
||||
if this.should_leave() {
|
||||
log::info!("room is empty, leaving");
|
||||
let _ = this.leave(cx);
|
||||
let _ = this.leave(cx).detach();
|
||||
}
|
||||
|
||||
this.user_store.update(cx, |user_store, cx| {
|
||||
|
||||
@@ -7,8 +7,9 @@ pub mod user;
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use async_recursion::async_recursion;
|
||||
use async_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
error::Error as WebsocketError,
|
||||
http::{Request, StatusCode},
|
||||
http::{HeaderValue, Request, StatusCode},
|
||||
};
|
||||
use clock::SystemClock;
|
||||
use collections::HashMap;
|
||||
@@ -235,6 +236,8 @@ pub enum EstablishConnectionError {
|
||||
#[error("{0}")]
|
||||
Http(#[from] http_client::Error),
|
||||
#[error("{0}")]
|
||||
InvalidHeaderValue(#[from] async_tungstenite::tungstenite::http::header::InvalidHeaderValue),
|
||||
#[error("{0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
Websocket(#[from] async_tungstenite::tungstenite::http::Error),
|
||||
@@ -1159,19 +1162,24 @@ impl Client {
|
||||
.ok()
|
||||
.unwrap_or_default();
|
||||
|
||||
let request = Request::builder()
|
||||
.header("Authorization", credentials.authorization_header())
|
||||
.header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
|
||||
.header("x-zed-app-version", app_version)
|
||||
.header(
|
||||
"x-zed-release-channel",
|
||||
release_channel.map(|r| r.dev_name()).unwrap_or("unknown"),
|
||||
);
|
||||
|
||||
let http = self.http.clone();
|
||||
let credentials = credentials.clone();
|
||||
let rpc_url = self.rpc_url(http, release_channel);
|
||||
cx.background_executor().spawn(async move {
|
||||
use HttpOrHttps::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum HttpOrHttps {
|
||||
Http,
|
||||
Https,
|
||||
}
|
||||
|
||||
let mut rpc_url = rpc_url.await?;
|
||||
let url_scheme = match rpc_url.scheme() {
|
||||
"https" => Https,
|
||||
"http" => Http,
|
||||
_ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
|
||||
};
|
||||
let rpc_host = rpc_url
|
||||
.host_str()
|
||||
.zip(rpc_url.port_or_known_default())
|
||||
@@ -1180,10 +1188,37 @@ impl Client {
|
||||
|
||||
log::info!("connected to rpc endpoint {}", rpc_url);
|
||||
|
||||
match rpc_url.scheme() {
|
||||
"https" => {
|
||||
rpc_url.set_scheme("wss").unwrap();
|
||||
let request = request.uri(rpc_url.as_str()).body(())?;
|
||||
rpc_url
|
||||
.set_scheme(match url_scheme {
|
||||
Https => "wss",
|
||||
Http => "ws",
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// We call `into_client_request` to let `tungstenite` construct the WebSocket request
|
||||
// for us from the RPC URL.
|
||||
//
|
||||
// Among other things, it will generate and set a `Sec-WebSocket-Key` header for us.
|
||||
let mut request = rpc_url.into_client_request()?;
|
||||
|
||||
// We then modify the request to add our desired headers.
|
||||
let request_headers = request.headers_mut();
|
||||
request_headers.insert(
|
||||
"Authorization",
|
||||
HeaderValue::from_str(&credentials.authorization_header())?,
|
||||
);
|
||||
request_headers.insert(
|
||||
"x-zed-protocol-version",
|
||||
HeaderValue::from_str(&rpc::PROTOCOL_VERSION.to_string())?,
|
||||
);
|
||||
request_headers.insert("x-zed-app-version", HeaderValue::from_str(&app_version)?);
|
||||
request_headers.insert(
|
||||
"x-zed-release-channel",
|
||||
HeaderValue::from_str(&release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
|
||||
);
|
||||
|
||||
match url_scheme {
|
||||
Https => {
|
||||
let (stream, _) =
|
||||
async_tungstenite::async_std::client_async_tls(request, stream).await?;
|
||||
Ok(Connection::new(
|
||||
@@ -1192,9 +1227,7 @@ impl Client {
|
||||
.sink_map_err(|error| anyhow!(error)),
|
||||
))
|
||||
}
|
||||
"http" => {
|
||||
rpc_url.set_scheme("ws").unwrap();
|
||||
let request = request.uri(rpc_url.as_str()).body(())?;
|
||||
Http => {
|
||||
let (stream, _) = async_tungstenite::client_async(request, stream).await?;
|
||||
Ok(Connection::new(
|
||||
stream
|
||||
@@ -1202,7 +1235,6 @@ impl Client {
|
||||
.sink_map_err(|error| anyhow!(error)),
|
||||
))
|
||||
}
|
||||
_ => Err(anyhow!("invalid rpc url: {}", rpc_url))?,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use sysinfo::{CpuRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System};
|
||||
use telemetry_events::{
|
||||
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CpuEvent, EditEvent,
|
||||
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent,
|
||||
MemoryEvent, SettingEvent,
|
||||
MemoryEvent, ReplEvent, SettingEvent,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
#[cfg(not(debug_assertions))]
|
||||
@@ -531,6 +531,21 @@ impl Telemetry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn report_repl_event(
|
||||
self: &Arc<Self>,
|
||||
kernel_language: String,
|
||||
kernel_status: String,
|
||||
repl_session_id: String,
|
||||
) {
|
||||
let event = Event::Repl(ReplEvent {
|
||||
kernel_language,
|
||||
kernel_status,
|
||||
repl_session_id,
|
||||
});
|
||||
|
||||
self.report_event(event)
|
||||
}
|
||||
|
||||
fn report_event(self: &Arc<Self>, event: Event) {
|
||||
let mut state = self.state.lock();
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ chrono.workspace = true
|
||||
clock.workspace = true
|
||||
clickhouse.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap = "5.4"
|
||||
dashmap.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
@@ -47,7 +47,7 @@ prost.workspace = true
|
||||
rand.workspace = true
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
rpc.workspace = true
|
||||
scrypt = "0.7"
|
||||
scrypt = "0.11"
|
||||
sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
|
||||
semantic_version.workspace = true
|
||||
semver.workspace = true
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use rpc::proto;
|
||||
use util::ResultExt as _;
|
||||
|
||||
pub fn language_model_request_to_open_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
) -> Result<open_ai::Request> {
|
||||
Ok(open_ai::Request {
|
||||
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message: proto::LanguageModelRequestMessage| {
|
||||
let role = proto::LanguageModelRole::from_i32(message.role)
|
||||
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
|
||||
|
||||
let openai_message = match role {
|
||||
proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
|
||||
content: message.content,
|
||||
},
|
||||
proto::LanguageModelRole::LanguageModelAssistant => {
|
||||
open_ai::RequestMessage::Assistant {
|
||||
content: Some(message.content),
|
||||
tool_calls: message
|
||||
.tool_calls
|
||||
.into_iter()
|
||||
.filter_map(|call| {
|
||||
Some(open_ai::ToolCall {
|
||||
id: call.id,
|
||||
content: match call.variant? {
|
||||
proto::tool_call::Variant::Function(f) => {
|
||||
open_ai::ToolCallContent::Function {
|
||||
function: open_ai::FunctionContent {
|
||||
name: f.name,
|
||||
arguments: f.arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
proto::LanguageModelRole::LanguageModelSystem => {
|
||||
open_ai::RequestMessage::System {
|
||||
content: message.content,
|
||||
}
|
||||
}
|
||||
proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
|
||||
tool_call_id: message
|
||||
.tool_call_id
|
||||
.ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
|
||||
content: message.content,
|
||||
},
|
||||
};
|
||||
|
||||
Ok(openai_message)
|
||||
})
|
||||
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
Some(match tool.variant? {
|
||||
proto::chat_completion_tool::Variant::Function(f) => {
|
||||
open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
name: f.name,
|
||||
description: f.description,
|
||||
parameters: if let Some(params) = &f.parameters {
|
||||
Some(
|
||||
serde_json::from_str(params)
|
||||
.context("failed to deserialize tool parameters")
|
||||
.log_err()?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: request.tool_choice,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_model_request_to_google_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
) -> Result<google_ai::GenerateContentRequest> {
|
||||
Ok(google_ai::GenerateContentRequest {
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(language_model_request_message_to_google_ai)
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
generation_config: None,
|
||||
safety_settings: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_model_request_message_to_google_ai(
|
||||
message: proto::LanguageModelRequestMessage,
|
||||
) -> Result<google_ai::Content> {
|
||||
let role = proto::LanguageModelRole::from_i32(message.role)
|
||||
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
|
||||
|
||||
Ok(google_ai::Content {
|
||||
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
||||
text: message.content,
|
||||
})],
|
||||
role: match role {
|
||||
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
|
||||
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
|
||||
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
|
||||
proto::LanguageModelRole::LanguageModelTool => {
|
||||
Err(anyhow!("we don't handle tool calls with google ai yet"))?
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn count_tokens_request_to_google_ai(
|
||||
request: proto::CountTokensWithLanguageModel,
|
||||
) -> Result<google_ai::CountTokensRequest> {
|
||||
Ok(google_ai::CountTokensRequest {
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(language_model_request_message_to_google_ai)
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
})
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod contributors;
|
||||
pub mod events;
|
||||
pub mod extensions;
|
||||
pub mod ips_file;
|
||||
@@ -5,13 +6,13 @@ pub mod slack;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
db::{ContributorSelector, User, UserId},
|
||||
db::{User, UserId},
|
||||
rpc, AppState, Error, Result,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{self, Path, Query},
|
||||
extract::{Path, Query},
|
||||
http::{self, Request, StatusCode},
|
||||
middleware::{self, Next},
|
||||
response::IntoResponse,
|
||||
@@ -19,7 +20,6 @@ use axum::{
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use axum_extra::response::ErasedJson;
|
||||
use chrono::SecondsFormat;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceBuilder;
|
||||
@@ -31,8 +31,7 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
|
||||
.route("/user", get(get_authenticated_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
|
||||
.route("/contributors", get(get_contributors).post(add_contributor))
|
||||
.route("/contributor", get(check_is_contributor))
|
||||
.merge(contributors::router())
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(Extension(state))
|
||||
@@ -126,66 +125,6 @@ async fn get_rpc_server_snapshot(
|
||||
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
|
||||
}
|
||||
|
||||
async fn get_contributors(Extension(app): Extension<Arc<AppState>>) -> Result<Json<Vec<String>>> {
|
||||
Ok(Json(app.db.get_contributors().await?))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CheckIsContributorParams {
|
||||
github_user_id: Option<i32>,
|
||||
github_login: Option<String>,
|
||||
}
|
||||
|
||||
impl CheckIsContributorParams {
|
||||
fn as_contributor_selector(self) -> Result<ContributorSelector> {
|
||||
if let Some(github_user_id) = self.github_user_id {
|
||||
return Ok(ContributorSelector::GitHubUserId { github_user_id });
|
||||
}
|
||||
|
||||
if let Some(github_login) = self.github_login {
|
||||
return Ok(ContributorSelector::GitHubLogin { github_login });
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"must be one of `github_user_id` or `github_login`."
|
||||
))?
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CheckIsContributorResponse {
|
||||
signed_at: Option<String>,
|
||||
}
|
||||
|
||||
async fn check_is_contributor(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<CheckIsContributorParams>,
|
||||
) -> Result<Json<CheckIsContributorResponse>> {
|
||||
let params = params.as_contributor_selector()?;
|
||||
Ok(Json(CheckIsContributorResponse {
|
||||
signed_at: app
|
||||
.db
|
||||
.get_contributor_sign_timestamp(¶ms)
|
||||
.await?
|
||||
.map(|ts| ts.and_utc().to_rfc3339_opts(SecondsFormat::Millis, true)),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn add_contributor(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(params): extract::Json<AuthenticatedUserParams>,
|
||||
) -> Result<()> {
|
||||
let initial_channel_id = app.config.auto_join_channel_id;
|
||||
app.db
|
||||
.add_contributor(
|
||||
¶ms.github_login,
|
||||
params.github_user_id,
|
||||
params.github_email.as_deref(),
|
||||
initial_channel_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CreateAccessTokenQueryParams {
|
||||
public_key: String,
|
||||
|
||||
121
crates/collab/src/api/contributors.rs
Normal file
121
crates/collab/src/api/contributors.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use axum::{
|
||||
extract::{self, Query},
|
||||
routing::get,
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use chrono::{NaiveDateTime, SecondsFormat};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::AuthenticatedUserParams;
|
||||
use crate::db::ContributorSelector;
|
||||
use crate::{AppState, Result};
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/contributors", get(get_contributors).post(add_contributor))
|
||||
.route("/contributor", get(check_is_contributor))
|
||||
}
|
||||
|
||||
async fn get_contributors(Extension(app): Extension<Arc<AppState>>) -> Result<Json<Vec<String>>> {
|
||||
Ok(Json(app.db.get_contributors().await?))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CheckIsContributorParams {
|
||||
github_user_id: Option<i32>,
|
||||
github_login: Option<String>,
|
||||
}
|
||||
|
||||
impl CheckIsContributorParams {
|
||||
fn as_contributor_selector(self) -> Result<ContributorSelector> {
|
||||
if let Some(github_user_id) = self.github_user_id {
|
||||
return Ok(ContributorSelector::GitHubUserId { github_user_id });
|
||||
}
|
||||
|
||||
if let Some(github_login) = self.github_login {
|
||||
return Ok(ContributorSelector::GitHubLogin { github_login });
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"must be one of `github_user_id` or `github_login`."
|
||||
))?
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CheckIsContributorResponse {
|
||||
signed_at: Option<String>,
|
||||
}
|
||||
|
||||
async fn check_is_contributor(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<CheckIsContributorParams>,
|
||||
) -> Result<Json<CheckIsContributorResponse>> {
|
||||
let params = params.as_contributor_selector()?;
|
||||
|
||||
if RenovateBot::is_renovate_bot(¶ms) {
|
||||
return Ok(Json(CheckIsContributorResponse {
|
||||
signed_at: Some(
|
||||
RenovateBot::created_at()
|
||||
.and_utc()
|
||||
.to_rfc3339_opts(SecondsFormat::Millis, true),
|
||||
),
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(Json(CheckIsContributorResponse {
|
||||
signed_at: app
|
||||
.db
|
||||
.get_contributor_sign_timestamp(¶ms)
|
||||
.await?
|
||||
.map(|ts| ts.and_utc().to_rfc3339_opts(SecondsFormat::Millis, true)),
|
||||
}))
|
||||
}
|
||||
|
||||
/// The Renovate bot GitHub user (`renovate[bot]`).
|
||||
///
|
||||
/// https://api.github.com/users/renovate[bot]
|
||||
struct RenovateBot;
|
||||
|
||||
impl RenovateBot {
|
||||
const LOGIN: &'static str = "renovate[bot]";
|
||||
const USER_ID: i32 = 29139614;
|
||||
|
||||
/// Returns the `created_at` timestamp for the Renovate bot user.
|
||||
fn created_at() -> &'static NaiveDateTime {
|
||||
static CREATED_AT: OnceLock<NaiveDateTime> = OnceLock::new();
|
||||
CREATED_AT.get_or_init(|| {
|
||||
chrono::DateTime::parse_from_rfc3339("2017-06-02T07:04:12Z")
|
||||
.expect("failed to parse 'created_at' for 'renovate[bot]'")
|
||||
.naive_utc()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns whether the given contributor selector corresponds to the Renovate bot user.
|
||||
fn is_renovate_bot(contributor: &ContributorSelector) -> bool {
|
||||
match contributor {
|
||||
ContributorSelector::GitHubLogin { github_login } => github_login == Self::LOGIN,
|
||||
ContributorSelector::GitHubUserId { github_user_id } => {
|
||||
github_user_id == &Self::USER_ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn add_contributor(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(params): extract::Json<AuthenticatedUserParams>,
|
||||
) -> Result<()> {
|
||||
let initial_channel_id = app.config.auto_join_channel_id;
|
||||
app.db
|
||||
.add_contributor(
|
||||
¶ms.github_login,
|
||||
params.github_user_id,
|
||||
params.github_email.as_deref(),
|
||||
initial_channel_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -16,7 +16,7 @@ use sha2::{Digest, Sha256};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use telemetry_events::{
|
||||
ActionEvent, AppEvent, AssistantEvent, CallEvent, CpuEvent, EditEvent, EditorEvent, Event,
|
||||
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent,
|
||||
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent, ReplEvent,
|
||||
SettingEvent,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
@@ -518,6 +518,13 @@ pub async fn post_events(
|
||||
checksum_matched,
|
||||
))
|
||||
}
|
||||
Event::Repl(event) => to_upload.repl_events.push(ReplEventRow::from_event(
|
||||
event.clone(),
|
||||
&wrapper,
|
||||
&request_body,
|
||||
first_event_at,
|
||||
checksum_matched,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -542,6 +549,7 @@ struct ToUpload {
|
||||
extension_events: Vec<ExtensionEventRow>,
|
||||
edit_events: Vec<EditEventRow>,
|
||||
action_events: Vec<ActionEventRow>,
|
||||
repl_events: Vec<ReplEventRow>,
|
||||
}
|
||||
|
||||
impl ToUpload {
|
||||
@@ -617,6 +625,11 @@ impl ToUpload {
|
||||
.await
|
||||
.with_context(|| format!("failed to upload to table '{ACTION_EVENTS_TABLE}'"))?;
|
||||
|
||||
const REPL_EVENTS_TABLE: &str = "repl_events";
|
||||
Self::upload_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
|
||||
.await
|
||||
.with_context(|| format!("failed to upload to table '{REPL_EVENTS_TABLE}'"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -625,22 +638,24 @@ impl ToUpload {
|
||||
rows: &[T],
|
||||
clickhouse_client: &clickhouse::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
if !rows.is_empty() {
|
||||
let mut insert = clickhouse_client.insert(table)?;
|
||||
|
||||
for event in rows {
|
||||
insert.write(event).await?;
|
||||
}
|
||||
|
||||
insert.end().await?;
|
||||
|
||||
let event_count = rows.len();
|
||||
log::info!(
|
||||
"wrote {event_count} {event_specifier} to '{table}'",
|
||||
event_specifier = if event_count == 1 { "event" } else { "events" }
|
||||
);
|
||||
if rows.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut insert = clickhouse_client.insert(table)?;
|
||||
|
||||
for event in rows {
|
||||
insert.write(event).await?;
|
||||
}
|
||||
|
||||
insert.end().await?;
|
||||
|
||||
let event_count = rows.len();
|
||||
log::info!(
|
||||
"wrote {event_count} {event_specifier} to '{table}'",
|
||||
event_specifier = if event_count == 1 { "event" } else { "events" }
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1189,6 +1204,62 @@ impl ExtensionEventRow {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct ReplEventRow {
|
||||
// AppInfoBase
|
||||
app_version: String,
|
||||
major: Option<i32>,
|
||||
minor: Option<i32>,
|
||||
patch: Option<i32>,
|
||||
checksum_matched: bool,
|
||||
release_channel: String,
|
||||
os_name: String,
|
||||
os_version: String,
|
||||
|
||||
// ClientEventBase
|
||||
installation_id: Option<String>,
|
||||
session_id: Option<String>,
|
||||
is_staff: Option<bool>,
|
||||
time: i64,
|
||||
|
||||
// ReplEventRow
|
||||
kernel_language: String,
|
||||
kernel_status: String,
|
||||
repl_session_id: String,
|
||||
}
|
||||
|
||||
impl ReplEventRow {
|
||||
fn from_event(
|
||||
event: ReplEvent,
|
||||
wrapper: &EventWrapper,
|
||||
body: &EventRequestBody,
|
||||
first_event_at: chrono::DateTime<chrono::Utc>,
|
||||
checksum_matched: bool,
|
||||
) -> Self {
|
||||
let semver = body.semver();
|
||||
let time =
|
||||
first_event_at + chrono::Duration::milliseconds(wrapper.milliseconds_since_first_event);
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
checksum_matched,
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
os_name: body.os_name.clone(),
|
||||
os_version: body.os_version.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
is_staff: body.is_staff,
|
||||
time: time.timestamp_millis(),
|
||||
kernel_language: event.kernel_language,
|
||||
kernel_status: event.kernel_status,
|
||||
repl_session_id: event.repl_session_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct EditEventRow {
|
||||
// AppInfoBase
|
||||
|
||||
@@ -9,6 +9,7 @@ use axum::{
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use base64::prelude::*;
|
||||
use prometheus::{exponential_buckets, register_histogram, Histogram};
|
||||
pub use rpc::auth::random_token;
|
||||
use scrypt::{
|
||||
@@ -155,10 +156,7 @@ pub async fn create_access_token(
|
||||
/// protection.
|
||||
pub fn hash_access_token(token: &str) -> String {
|
||||
let digest = sha2::Sha256::digest(token);
|
||||
format!(
|
||||
"$sha256${}",
|
||||
base64::encode_config(digest, base64::URL_SAFE)
|
||||
)
|
||||
format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
|
||||
}
|
||||
|
||||
/// Encrypts the given access token with the given public key to avoid leaking it on the way
|
||||
@@ -402,15 +400,16 @@ mod test {
|
||||
fn previous_hash_access_token(token: &str) -> Result<String> {
|
||||
// Avoid slow hashing in debug mode.
|
||||
let params = if cfg!(debug_assertions) {
|
||||
scrypt::Params::new(1, 1, 1).unwrap()
|
||||
scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
|
||||
} else {
|
||||
scrypt::Params::new(14, 8, 1).unwrap()
|
||||
scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
|
||||
};
|
||||
|
||||
Ok(Scrypt
|
||||
.hash_password(
|
||||
.hash_password_customized(
|
||||
token.as_bytes(),
|
||||
None,
|
||||
None,
|
||||
params,
|
||||
&SaltString::generate(thread_rng()),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod ai;
|
||||
pub mod api;
|
||||
pub mod auth;
|
||||
pub mod db;
|
||||
|
||||
@@ -153,7 +153,7 @@ async fn main() -> Result<()> {
|
||||
let signal = async move {
|
||||
// todo(windows):
|
||||
// `ctrl_close` does not work well, because tokio's signal handler always returns soon,
|
||||
// but system termiates the application soon after returning CTRL+CLOSE handler.
|
||||
// but system terminates the application soon after returning CTRL+CLOSE handler.
|
||||
// So we should implement blocking handler to treat CTRL+CLOSE signal.
|
||||
let mut ctrl_break = tokio::signal::windows::ctrl_break()
|
||||
.expect("failed to listen for interrupt signal");
|
||||
|
||||
@@ -10,9 +10,9 @@ use crate::{
|
||||
ServerId, UpdatedChannelMessage, User, UserId,
|
||||
},
|
||||
executor::Executor,
|
||||
AppState, Error, RateLimit, RateLimiter, Result,
|
||||
AppState, Config, Error, RateLimit, RateLimiter, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use anyhow::{anyhow, bail, Context as _};
|
||||
use async_tungstenite::tungstenite::{
|
||||
protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
|
||||
};
|
||||
@@ -46,8 +46,8 @@ use http_client::IsahcHttpClient;
|
||||
use prometheus::{register_int_gauge, IntGauge};
|
||||
use rpc::{
|
||||
proto::{
|
||||
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
|
||||
LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
|
||||
self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
|
||||
RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
|
||||
},
|
||||
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
|
||||
};
|
||||
@@ -605,29 +605,40 @@ impl Server {
|
||||
))
|
||||
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
|
||||
.add_message_handler(update_context)
|
||||
.add_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
let app_state = app_state.clone();
|
||||
async move {
|
||||
complete_with_language_model(request, response, session, &app_state.config)
|
||||
.await
|
||||
}
|
||||
}
|
||||
})
|
||||
.add_streaming_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
move |request, response, session| {
|
||||
complete_with_language_model(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
app_state.config.openai_api_key.clone(),
|
||||
app_state.config.google_ai_api_key.clone(),
|
||||
app_state.config.anthropic_api_key.clone(),
|
||||
)
|
||||
let app_state = app_state.clone();
|
||||
async move {
|
||||
stream_complete_with_language_model(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
&app_state.config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
})
|
||||
.add_request_handler({
|
||||
let app_state = app_state.clone();
|
||||
user_handler(move |request, response, session| {
|
||||
count_tokens_with_language_model(
|
||||
request,
|
||||
response,
|
||||
session,
|
||||
app_state.config.google_ai_api_key.clone(),
|
||||
)
|
||||
})
|
||||
move |request, response, session| {
|
||||
let app_state = app_state.clone();
|
||||
async move {
|
||||
count_language_model_tokens(request, response, session, &app_state.config)
|
||||
.await
|
||||
}
|
||||
}
|
||||
})
|
||||
.add_request_handler({
|
||||
user_handler(move |request, response, session| {
|
||||
@@ -1392,7 +1403,7 @@ pub async fn handle_websocket_request(
|
||||
let socket = socket
|
||||
.map_ok(to_tungstenite_message)
|
||||
.err_into()
|
||||
.with(|message| async move { Ok(to_axum_message(message)) });
|
||||
.with(|message| async move { to_axum_message(message) });
|
||||
let connection = Connection::new(Box::pin(socket));
|
||||
async move {
|
||||
server
|
||||
@@ -4514,310 +4525,172 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
|
||||
}
|
||||
|
||||
async fn complete_with_language_model(
|
||||
mut request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: Response<proto::CompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
open_ai_api_key: Option<Arc<str>>,
|
||||
google_ai_api_key: Option<Arc<str>>,
|
||||
anthropic_api_key: Option<Arc<str>>,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.await?;
|
||||
|
||||
let mut provider_and_model = request.model.split('/');
|
||||
let (provider, model) = match (
|
||||
provider_and_model.next().unwrap(),
|
||||
provider_and_model.next(),
|
||||
) {
|
||||
(provider, Some(model)) => (provider, model),
|
||||
(model, None) => {
|
||||
if model.starts_with("gpt") {
|
||||
("openai", model)
|
||||
} else if model.starts_with("gemini") {
|
||||
("google", model)
|
||||
} else if model.starts_with("claude") {
|
||||
("anthropic", model)
|
||||
} else {
|
||||
("unknown", model)
|
||||
}
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
||||
let api_key = config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?;
|
||||
anthropic::complete(
|
||||
session.http_client.as_ref(),
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
_ => return Err(anyhow!("unsupported provider"))?,
|
||||
};
|
||||
let provider = provider.to_string();
|
||||
request.model = model.to_string();
|
||||
|
||||
match provider.as_str() {
|
||||
"openai" => {
|
||||
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
|
||||
complete_with_open_ai(request, response, session, api_key).await?;
|
||||
}
|
||||
"anthropic" => {
|
||||
let api_key =
|
||||
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
|
||||
complete_with_anthropic(request, response, session, api_key).await?;
|
||||
}
|
||||
"google" => {
|
||||
let api_key =
|
||||
google_ai_api_key.context("no Google AI API key configured on the server")?;
|
||||
complete_with_google_ai(request, response, session, api_key).await?;
|
||||
}
|
||||
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
|
||||
}
|
||||
response.send(proto::CompleteWithLanguageModelResponse {
|
||||
completion: serde_json::to_string(&result)?,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn complete_with_open_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: UserSession,
|
||||
api_key: Arc<str>,
|
||||
async fn stream_complete_with_language_model(
|
||||
request: proto::StreamCompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
|
||||
session: Session,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let mut completion_stream = open_ai::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
OPEN_AI_API_URL,
|
||||
&api_key,
|
||||
crate::ai::language_model_request_to_open_ai(request)?,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.context("open_ai::stream_completion request failed within collab")?;
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
while let Some(event) = completion_stream.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: event
|
||||
.choices
|
||||
.into_iter()
|
||||
.map(|choice| proto::LanguageModelChoiceDelta {
|
||||
index: choice.index,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: choice.delta.role.map(|role| match role {
|
||||
open_ai::Role::User => LanguageModelRole::LanguageModelUser,
|
||||
open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
|
||||
open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
|
||||
open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
|
||||
} as i32),
|
||||
content: choice.delta.content,
|
||||
tool_calls: choice
|
||||
.delta
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|delta| proto::ToolCallDelta {
|
||||
index: delta.index as u32,
|
||||
id: delta.id,
|
||||
variant: match delta.function {
|
||||
Some(function) => {
|
||||
let name = function.name;
|
||||
let arguments = function.arguments;
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
|
||||
.await?;
|
||||
|
||||
Some(proto::tool_call_delta::Variant::Function(
|
||||
proto::tool_call_delta::FunctionCallDelta {
|
||||
name,
|
||||
arguments,
|
||||
},
|
||||
))
|
||||
}
|
||||
None => None,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
finish_reason: choice.finish_reason,
|
||||
})
|
||||
.collect(),
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn complete_with_google_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: UserSession,
|
||||
api_key: Arc<str>,
|
||||
) -> Result<()> {
|
||||
let mut stream = google_ai::stream_generate_content(
|
||||
session.http_client.clone(),
|
||||
google_ai::API_URL,
|
||||
api_key.as_ref(),
|
||||
&request.model.clone(),
|
||||
crate::ai::language_model_request_to_google_ai(request)?,
|
||||
)
|
||||
.await
|
||||
.context("google_ai::stream_generate_content request failed")?;
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: event
|
||||
.candidates
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|candidate| proto::LanguageModelChoiceDelta {
|
||||
index: candidate.index as u32,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: Some(match candidate.content.role {
|
||||
google_ai::Role::User => LanguageModelRole::LanguageModelUser,
|
||||
google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
|
||||
} as i32),
|
||||
content: Some(
|
||||
candidate
|
||||
.content
|
||||
.parts
|
||||
.into_iter()
|
||||
.filter_map(|part| match part {
|
||||
google_ai::Part::TextPart(part) => Some(part.text),
|
||||
google_ai::Part::InlineDataPart(_) => None,
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
// Tool calls are not supported for Google
|
||||
tool_calls: Vec::new(),
|
||||
}),
|
||||
finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
|
||||
})
|
||||
.collect(),
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn complete_with_anthropic(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
response: StreamingResponse<proto::CompleteWithLanguageModel>,
|
||||
session: UserSession,
|
||||
api_key: Arc<str>,
|
||||
) -> Result<()> {
|
||||
let model = anthropic::Model::from_id(&request.model)?;
|
||||
|
||||
let mut system_message = String::new();
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.filter_map(|message| {
|
||||
match message.role() {
|
||||
LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
|
||||
role: anthropic::Role::User,
|
||||
content: message.content,
|
||||
}),
|
||||
LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
|
||||
role: anthropic::Role::Assistant,
|
||||
content: message.content,
|
||||
}),
|
||||
// Anthropic's API breaks system instructions out as a separate field rather
|
||||
// than having a system message role.
|
||||
LanguageModelRole::LanguageModelSystem => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.content);
|
||||
|
||||
None
|
||||
}
|
||||
// We don't yet support tool calls for Anthropic
|
||||
LanguageModelRole::LanguageModelTool => None,
|
||||
match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
Some(proto::LanguageModelProvider::Anthropic) => {
|
||||
let api_key = config
|
||||
.anthropic_api_key
|
||||
.as_ref()
|
||||
.context("no Anthropic AI API key configured on the server")?;
|
||||
let mut chunks = anthropic::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = chunks.next().await {
|
||||
let chunk = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&chunk)?,
|
||||
})?;
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut stream = anthropic::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
&api_key,
|
||||
anthropic::Request {
|
||||
model,
|
||||
messages,
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
let event = event?;
|
||||
|
||||
match event {
|
||||
anthropic::ResponseEvent::MessageStart { message } => {
|
||||
if let Some(role) = message.role {
|
||||
if role == "assistant" {
|
||||
current_role = proto::LanguageModelRole::LanguageModelAssistant;
|
||||
} else if role == "user" {
|
||||
current_role = proto::LanguageModelRole::LanguageModelUser;
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
|
||||
match content_block {
|
||||
anthropic::ContentBlock::Text { text } => {
|
||||
if !text.is_empty() {
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: vec![proto::LanguageModelChoiceDelta {
|
||||
index: 0,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: Some(current_role as i32),
|
||||
content: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
}),
|
||||
finish_reason: None,
|
||||
}],
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => {
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: vec![proto::LanguageModelChoiceDelta {
|
||||
index: 0,
|
||||
delta: Some(proto::LanguageModelResponseMessage {
|
||||
role: Some(current_role as i32),
|
||||
content: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
}),
|
||||
finish_reason: None,
|
||||
}],
|
||||
})?;
|
||||
}
|
||||
},
|
||||
anthropic::ResponseEvent::MessageDelta { delta, .. } => {
|
||||
if let Some(stop_reason) = delta.stop_reason {
|
||||
response.send(proto::LanguageModelResponse {
|
||||
choices: vec![proto::LanguageModelChoiceDelta {
|
||||
index: 0,
|
||||
delta: None,
|
||||
finish_reason: Some(stop_reason),
|
||||
}],
|
||||
})?;
|
||||
}
|
||||
}
|
||||
anthropic::ResponseEvent::ContentBlockStop { .. } => {}
|
||||
anthropic::ResponseEvent::MessageStop {} => {}
|
||||
anthropic::ResponseEvent::Ping {} => {}
|
||||
}
|
||||
Some(proto::LanguageModelProvider::OpenAi) => {
|
||||
let api_key = config
|
||||
.openai_api_key
|
||||
.as_ref()
|
||||
.context("no OpenAI API key configured on the server")?;
|
||||
let mut events = open_ai::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
open_ai::OPEN_AI_API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&event)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
Some(proto::LanguageModelProvider::Google) => {
|
||||
let api_key = config
|
||||
.google_ai_api_key
|
||||
.as_ref()
|
||||
.context("no Google AI API key configured on the server")?;
|
||||
let mut events = google_ai::stream_generate_content(
|
||||
session.http_client.as_ref(),
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
)
|
||||
.await?;
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
response.send(proto::StreamCompleteWithLanguageModelResponse {
|
||||
event: serde_json::to_string(&event)?,
|
||||
})?;
|
||||
}
|
||||
}
|
||||
None => return Err(anyhow!("unknown provider"))?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CountTokensWithLanguageModelRateLimit;
|
||||
async fn count_language_model_tokens(
|
||||
request: proto::CountLanguageModelTokens,
|
||||
response: Response<proto::CountLanguageModelTokens>,
|
||||
session: Session,
|
||||
config: &Config,
|
||||
) -> Result<()> {
|
||||
let Some(session) = session.for_user() else {
|
||||
return Err(anyhow!("user not found"))?;
|
||||
};
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CountLanguageModelTokensRateLimit>(session.user_id())
|
||||
.await?;
|
||||
|
||||
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
|
||||
Some(proto::LanguageModelProvider::Google) => {
|
||||
let api_key = config
|
||||
.google_ai_api_key
|
||||
.as_ref()
|
||||
.context("no Google AI API key configured on the server")?;
|
||||
google_ai::count_tokens(
|
||||
session.http_client.as_ref(),
|
||||
google_ai::API_URL,
|
||||
api_key,
|
||||
serde_json::from_str(&request.request)?,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
_ => return Err(anyhow!("unsupported provider"))?,
|
||||
};
|
||||
|
||||
response.send(proto::CountLanguageModelTokensResponse {
|
||||
token_count: result.total_tokens as u32,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct CountLanguageModelTokensRateLimit;
|
||||
|
||||
impl RateLimit for CountLanguageModelTokensRateLimit {
|
||||
fn capacity() -> usize {
|
||||
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
|
||||
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(600) // Picked arbitrarily
|
||||
@@ -4828,45 +4701,10 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
|
||||
}
|
||||
|
||||
fn db_name() -> &'static str {
|
||||
"count-tokens-with-language-model"
|
||||
"count-language-model-tokens"
|
||||
}
|
||||
}
|
||||
|
||||
async fn count_tokens_with_language_model(
|
||||
request: proto::CountTokensWithLanguageModel,
|
||||
response: Response<proto::CountTokensWithLanguageModel>,
|
||||
session: UserSession,
|
||||
google_ai_api_key: Option<Arc<str>>,
|
||||
) -> Result<()> {
|
||||
authorize_access_to_language_models(&session).await?;
|
||||
|
||||
if !request.model.starts_with("gemini") {
|
||||
return Err(anyhow!(
|
||||
"counting tokens for model: {:?} is not supported",
|
||||
request.model
|
||||
))?;
|
||||
}
|
||||
|
||||
session
|
||||
.rate_limiter
|
||||
.check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
|
||||
.await?;
|
||||
|
||||
let api_key = google_ai_api_key
|
||||
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
|
||||
let tokens_response = google_ai::count_tokens(
|
||||
session.http_client.as_ref(),
|
||||
google_ai::API_URL,
|
||||
&api_key,
|
||||
crate::ai::count_tokens_request_to_google_ai(request)?,
|
||||
)
|
||||
.await?;
|
||||
response.send(proto::CountTokensResponse {
|
||||
token_count: tokens_response.total_tokens as u32,
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct ComputeEmbeddingsRateLimit;
|
||||
|
||||
impl RateLimit for ComputeEmbeddingsRateLimit {
|
||||
@@ -5154,8 +4992,8 @@ async fn get_private_user_info(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
|
||||
match message {
|
||||
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
|
||||
let message = match message {
|
||||
TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
|
||||
TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
|
||||
TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
|
||||
@@ -5164,7 +5002,20 @@ fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
|
||||
code: frame.code.into(),
|
||||
reason: frame.reason,
|
||||
})),
|
||||
}
|
||||
// We should never receive a frame while reading the message, according
|
||||
// to the `tungstenite` maintainers:
|
||||
//
|
||||
// > It cannot occur when you read messages from the WebSocket, but it
|
||||
// > can be used when you want to send the raw frames (e.g. you want to
|
||||
// > send the frames to the WebSocket without composing the full message first).
|
||||
// >
|
||||
// > — https://github.com/snapview/tungstenite-rs/issues/268
|
||||
TungsteniteMessage::Frame(_) => {
|
||||
bail!("received an unexpected frame while reading the message")
|
||||
}
|
||||
};
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::reversed_empty_ranges)]
|
||||
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
|
||||
use call::{ActiveCall, ParticipantLocation};
|
||||
use client::ChannelId;
|
||||
|
||||
@@ -533,6 +533,7 @@ impl Render for MessageEditor {
|
||||
},
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
||||
font_size: TextSize::Small.rems(cx).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
|
||||
@@ -2190,6 +2190,7 @@ impl CollabPanel {
|
||||
},
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
font_style: FontStyle::Normal,
|
||||
|
||||
@@ -26,7 +26,9 @@ anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
ui.workspace = true
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
|
||||
use gpui::{AppContext, Global, Model, ModelContext, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
|
||||
LanguageModelRequest,
|
||||
LanguageModelRequest, LanguageModelTool,
|
||||
};
|
||||
use smol::lock::{Semaphore, SemaphoreGuardArc};
|
||||
use std::{pin::Pin, sync::Arc, task::Poll};
|
||||
use smol::{future::FutureExt, lock::{Semaphore, SemaphoreGuardArc}};
|
||||
use std::{future, pin::Pin, sync::Arc, task::Poll};
|
||||
use ui::Context;
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
@@ -27,7 +27,7 @@ pub struct LanguageModelCompletionProvider {
|
||||
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
|
||||
|
||||
pub struct LanguageModelCompletionResponse {
|
||||
pub inner: BoxStream<'static, Result<String>>,
|
||||
inner: BoxStream<'static, Result<String>>,
|
||||
_lock: SemaphoreGuardArc,
|
||||
}
|
||||
|
||||
@@ -147,7 +147,7 @@ impl LanguageModelCompletionProvider {
|
||||
if let Some(model) = self.active_model() {
|
||||
model.count_tokens(request, cx)
|
||||
} else {
|
||||
std::future::ready(Err(anyhow!("No active model set"))).boxed()
|
||||
future::ready(Err(anyhow!("no active model"))).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,6 +183,29 @@ impl LanguageModelCompletionProvider {
|
||||
Ok(completion)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn use_tool<T: LanguageModelTool>(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> Task<Result<T>> {
|
||||
if let Some(language_model) = self.active_model() {
|
||||
cx.spawn(|cx| async move {
|
||||
let schema = schemars::schema_for!(T);
|
||||
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||
let request =
|
||||
language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
|
||||
let response = request.await?;
|
||||
Ok(serde_json::from_value(response)?)
|
||||
})
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("No active model set")))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_model_telemetry_id(&self) -> Option<String> {
|
||||
self.active_model.as_ref().map(|m| m.telemetry_id())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -691,7 +691,7 @@ impl Copilot {
|
||||
{
|
||||
match event {
|
||||
language::Event::Edited => {
|
||||
let _ = registered_buffer.report_changes(&buffer, cx);
|
||||
drop(registered_buffer.report_changes(&buffer, cx));
|
||||
}
|
||||
language::Event::Saved => {
|
||||
server
|
||||
|
||||
@@ -333,7 +333,7 @@ mod tests {
|
||||
three
|
||||
"});
|
||||
cx.simulate_keystroke(".");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one.|<>
|
||||
@@ -341,7 +341,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec!["completion_a", "completion_b"],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -375,7 +375,7 @@ mod tests {
|
||||
three
|
||||
"});
|
||||
cx.simulate_keystroke(".");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one.|<>
|
||||
@@ -383,7 +383,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec![],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -408,7 +408,7 @@ mod tests {
|
||||
three
|
||||
"});
|
||||
cx.simulate_keystroke(".");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one.|<>
|
||||
@@ -416,7 +416,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec!["completion_a", "completion_b"],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -590,7 +590,7 @@ mod tests {
|
||||
three
|
||||
"});
|
||||
cx.simulate_keystroke(".");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one.|<>
|
||||
@@ -598,7 +598,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec![],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -632,7 +632,7 @@ mod tests {
|
||||
three
|
||||
"});
|
||||
cx.simulate_keystroke(".");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one.|<>
|
||||
@@ -640,7 +640,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec![],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -889,7 +889,7 @@ mod tests {
|
||||
three
|
||||
"});
|
||||
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one
|
||||
@@ -897,7 +897,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec!["completion_a", "completion_b"],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -917,7 +917,7 @@ mod tests {
|
||||
});
|
||||
|
||||
cx.simulate_keystroke("o");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one
|
||||
@@ -925,7 +925,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec!["completion_a_2", "completion_b_2"],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
@@ -944,7 +944,7 @@ mod tests {
|
||||
});
|
||||
|
||||
cx.simulate_keystroke(".");
|
||||
let _ = handle_completion_request(
|
||||
drop(handle_completion_request(
|
||||
&mut cx,
|
||||
indoc! {"
|
||||
one
|
||||
@@ -952,7 +952,7 @@ mod tests {
|
||||
three
|
||||
"},
|
||||
vec!["something_else()"],
|
||||
);
|
||||
));
|
||||
handle_copilot_completion_request(
|
||||
&copilot_lsp,
|
||||
vec![crate::request::Completion {
|
||||
|
||||
@@ -109,6 +109,7 @@ pub struct DisplayMap {
|
||||
crease_map: CreaseMap,
|
||||
fold_placeholder: FoldPlaceholder,
|
||||
pub clip_at_line_ends: bool,
|
||||
pub(crate) masked: bool,
|
||||
}
|
||||
|
||||
impl DisplayMap {
|
||||
@@ -156,6 +157,7 @@ impl DisplayMap {
|
||||
text_highlights: Default::default(),
|
||||
inlay_highlights: Default::default(),
|
||||
clip_at_line_ends: false,
|
||||
masked: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +184,7 @@ impl DisplayMap {
|
||||
text_highlights: self.text_highlights.clone(),
|
||||
inlay_highlights: self.inlay_highlights.clone(),
|
||||
clip_at_line_ends: self.clip_at_line_ends,
|
||||
masked: self.masked,
|
||||
fold_placeholder: self.fold_placeholder.clone(),
|
||||
}
|
||||
}
|
||||
@@ -499,6 +502,7 @@ pub struct DisplaySnapshot {
|
||||
text_highlights: TextHighlights,
|
||||
inlay_highlights: InlayHighlights,
|
||||
clip_at_line_ends: bool,
|
||||
masked: bool,
|
||||
pub(crate) fold_placeholder: FoldPlaceholder,
|
||||
}
|
||||
|
||||
@@ -650,6 +654,7 @@ impl DisplaySnapshot {
|
||||
.chunks(
|
||||
display_row.0..self.max_point().row().next_row().0,
|
||||
false,
|
||||
self.masked,
|
||||
Highlights::default(),
|
||||
)
|
||||
.map(|h| h.text)
|
||||
@@ -657,9 +662,9 @@ impl DisplaySnapshot {
|
||||
|
||||
/// Returns text chunks starting at the end of the given display row in reverse until the start of the file
|
||||
pub fn reverse_text_chunks(&self, display_row: DisplayRow) -> impl Iterator<Item = &str> {
|
||||
(0..=display_row.0).rev().flat_map(|row| {
|
||||
(0..=display_row.0).rev().flat_map(move |row| {
|
||||
self.block_snapshot
|
||||
.chunks(row..row + 1, false, Highlights::default())
|
||||
.chunks(row..row + 1, false, self.masked, Highlights::default())
|
||||
.map(|h| h.text)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
@@ -676,6 +681,7 @@ impl DisplaySnapshot {
|
||||
self.block_snapshot.chunks(
|
||||
display_rows.start.0..display_rows.end.0,
|
||||
language_aware,
|
||||
self.masked,
|
||||
Highlights {
|
||||
text_highlights: Some(&self.text_highlights),
|
||||
inlay_highlights: Some(&self.inlay_highlights),
|
||||
|
||||
@@ -23,6 +23,7 @@ use text::Edit;
|
||||
use ui::ElementId;
|
||||
|
||||
const NEWLINES: &[u8] = &[b'\n'; u8::MAX as usize];
|
||||
const BULLETS: &str = "********************************************************************************************************************************";
|
||||
|
||||
/// Tracks custom blocks such as diagnostics that should be displayed within buffer.
|
||||
///
|
||||
@@ -285,6 +286,7 @@ pub struct BlockChunks<'a> {
|
||||
input_chunk: Chunk<'a>,
|
||||
output_row: u32,
|
||||
max_output_row: u32,
|
||||
masked: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -893,6 +895,7 @@ impl BlockSnapshot {
|
||||
self.chunks(
|
||||
0..self.transforms.summary().output_rows,
|
||||
false,
|
||||
false,
|
||||
Highlights::default(),
|
||||
)
|
||||
.map(|chunk| chunk.text)
|
||||
@@ -903,6 +906,7 @@ impl BlockSnapshot {
|
||||
&'a self,
|
||||
rows: Range<u32>,
|
||||
language_aware: bool,
|
||||
masked: bool,
|
||||
highlights: Highlights<'a>,
|
||||
) -> BlockChunks<'a> {
|
||||
let max_output_row = cmp::min(rows.end, self.transforms.summary().output_rows);
|
||||
@@ -941,6 +945,7 @@ impl BlockSnapshot {
|
||||
transforms: cursor,
|
||||
output_row: rows.start,
|
||||
max_output_row,
|
||||
masked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1229,12 +1234,20 @@ impl<'a> Iterator for BlockChunks<'a> {
|
||||
let (prefix_rows, prefix_bytes) =
|
||||
offset_for_row(self.input_chunk.text, transform_end - self.output_row);
|
||||
self.output_row += prefix_rows;
|
||||
let (prefix, suffix) = self.input_chunk.text.split_at(prefix_bytes);
|
||||
let (mut prefix, suffix) = self.input_chunk.text.split_at(prefix_bytes);
|
||||
self.input_chunk.text = suffix;
|
||||
if self.output_row == transform_end {
|
||||
self.transforms.next(&());
|
||||
}
|
||||
|
||||
if self.masked {
|
||||
// Not great for multibyte text because to keep cursor math correct we
|
||||
// need to have the same number of bytes in the input as output.
|
||||
let chars = prefix.chars().count();
|
||||
let bullet_len = chars;
|
||||
prefix = &BULLETS[..bullet_len];
|
||||
}
|
||||
|
||||
Some(Chunk {
|
||||
text: prefix,
|
||||
..self.input_chunk.clone()
|
||||
@@ -2048,6 +2061,7 @@ mod tests {
|
||||
.chunks(
|
||||
start_row as u32..blocks_snapshot.max_point().row + 1,
|
||||
false,
|
||||
false,
|
||||
Highlights::default(),
|
||||
)
|
||||
.map(|chunk| chunk.text)
|
||||
|
||||
@@ -408,6 +408,7 @@ impl EditorActionId {
|
||||
type BackgroundHighlight = (fn(&ThemeColors) -> Hsla, Arc<[Range<Anchor>]>);
|
||||
type GutterHighlight = (fn(&AppContext) -> Hsla, Arc<[Range<Anchor>]>);
|
||||
|
||||
#[derive(Default)]
|
||||
struct ScrollbarMarkerState {
|
||||
scrollbar_size: Size<Pixels>,
|
||||
dirty: bool,
|
||||
@@ -421,17 +422,6 @@ impl ScrollbarMarkerState {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ScrollbarMarkerState {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
scrollbar_size: Size::default(),
|
||||
dirty: false,
|
||||
markers: Arc::from([]),
|
||||
pending_refresh: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct RunnableTasks {
|
||||
templates: Vec<(TaskSourceKind, TaskTemplate)>,
|
||||
@@ -490,7 +480,6 @@ pub struct Editor {
|
||||
mode: EditorMode,
|
||||
show_breadcrumbs: bool,
|
||||
show_gutter: bool,
|
||||
redact_all: bool,
|
||||
show_line_numbers: Option<bool>,
|
||||
show_git_diff_gutter: Option<bool>,
|
||||
show_code_actions: Option<bool>,
|
||||
@@ -1813,7 +1802,6 @@ impl Editor {
|
||||
show_code_actions: None,
|
||||
show_runnables: None,
|
||||
show_wrap_guides: None,
|
||||
redact_all: false,
|
||||
show_indent_guides,
|
||||
placeholder_text: None,
|
||||
highlight_order: 0,
|
||||
@@ -5730,7 +5718,7 @@ impl Editor {
|
||||
|
||||
self.transact(cx, |this, cx| {
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
let empty_str: Arc<str> = "".into();
|
||||
let empty_str: Arc<str> = Arc::default();
|
||||
buffer.edit(
|
||||
deletion_ranges
|
||||
.into_iter()
|
||||
@@ -5796,7 +5784,7 @@ impl Editor {
|
||||
|
||||
self.transact(cx, |this, cx| {
|
||||
let buffer = this.buffer.update(cx, |buffer, cx| {
|
||||
let empty_str: Arc<str> = "".into();
|
||||
let empty_str: Arc<str> = Arc::default();
|
||||
buffer.edit(
|
||||
edit_ranges
|
||||
.into_iter()
|
||||
@@ -8097,7 +8085,7 @@ impl Editor {
|
||||
let mut selection_edit_ranges = Vec::new();
|
||||
let mut last_toggled_row = None;
|
||||
let snapshot = this.buffer.read(cx).read(cx);
|
||||
let empty_str: Arc<str> = "".into();
|
||||
let empty_str: Arc<str> = Arc::default();
|
||||
let mut suffixes_inserted = Vec::new();
|
||||
|
||||
fn comment_prefix_range(
|
||||
@@ -10430,9 +10418,11 @@ impl Editor {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn set_redact_all(&mut self, redact_all: bool, cx: &mut ViewContext<Self>) {
|
||||
self.redact_all = redact_all;
|
||||
cx.notify();
|
||||
pub fn set_masked(&mut self, masked: bool, cx: &mut ViewContext<Self>) {
|
||||
if self.display_map.read(cx).masked != masked {
|
||||
self.display_map.update(cx, |map, _| map.masked = masked);
|
||||
}
|
||||
cx.notify()
|
||||
}
|
||||
|
||||
pub fn set_show_wrap_guides(&mut self, show_wrap_guides: bool, cx: &mut ViewContext<Self>) {
|
||||
@@ -11118,10 +11108,6 @@ impl Editor {
|
||||
display_snapshot: &DisplaySnapshot,
|
||||
cx: &WindowContext,
|
||||
) -> Vec<Range<DisplayPoint>> {
|
||||
if self.redact_all {
|
||||
return vec![DisplayPoint::zero()..display_snapshot.max_point()];
|
||||
}
|
||||
|
||||
display_snapshot
|
||||
.buffer_snapshot
|
||||
.redacted_ranges(search_range, |file| {
|
||||
@@ -12444,6 +12430,7 @@ impl Render for Editor {
|
||||
color: cx.theme().colors().editor_foreground,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
line_height: relative(settings.buffer_line_height.value()),
|
||||
@@ -12453,6 +12440,7 @@ impl Render for Editor {
|
||||
color: cx.theme().colors().editor_foreground,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_size: settings.buffer_font_size(cx).into(),
|
||||
font_weight: settings.buffer_font.weight,
|
||||
line_height: relative(settings.buffer_line_height.value()),
|
||||
|
||||
@@ -305,7 +305,7 @@ pub struct ScrollbarContent {
|
||||
}
|
||||
|
||||
/// Gutter related settings
|
||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
|
||||
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
|
||||
pub struct GutterContent {
|
||||
/// Whether to show line numbers in the gutter.
|
||||
///
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use gpui::{AppContext, FontWeight};
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::{AppContext, FontFeatures, FontWeight};
|
||||
use project::project_settings::{InlineBlameSettings, ProjectSettings};
|
||||
use settings::{EditableSettingControl, Settings};
|
||||
use theme::{FontFamilyCache, ThemeSettings};
|
||||
@@ -7,6 +9,8 @@ use ui::{
|
||||
SettingsGroup,
|
||||
};
|
||||
|
||||
use crate::EditorSettings;
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct EditorSettingsControls {}
|
||||
|
||||
@@ -28,9 +32,19 @@ impl RenderOnce for EditorSettingsControls {
|
||||
.child(BufferFontFamilyControl)
|
||||
.child(BufferFontWeightControl),
|
||||
)
|
||||
.child(BufferFontSizeControl),
|
||||
.child(BufferFontSizeControl)
|
||||
.child(BufferFontLigaturesControl),
|
||||
)
|
||||
.child(SettingsGroup::new("Editor").child(InlineGitBlameControl))
|
||||
.child(
|
||||
SettingsGroup::new("Gutter").child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.child(LineNumbersControl)
|
||||
.child(RelativeLineNumbersControl),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,6 +140,7 @@ impl RenderOnce for BufferFontSizeControl {
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::FontSize))
|
||||
.child(NumericStepper::new(
|
||||
"buffer-font-size",
|
||||
value.to_string(),
|
||||
move |_, cx| {
|
||||
Self::write(value - px(1.), cx);
|
||||
@@ -190,6 +205,76 @@ impl RenderOnce for BufferFontWeightControl {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
struct BufferFontLigaturesControl;
|
||||
|
||||
impl EditableSettingControl for BufferFontLigaturesControl {
|
||||
type Value = bool;
|
||||
type Settings = ThemeSettings;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Buffer Font Ligatures".into()
|
||||
}
|
||||
|
||||
fn read(cx: &AppContext) -> Self::Value {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
settings
|
||||
.buffer_font
|
||||
.features
|
||||
.is_calt_enabled()
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
fn apply(
|
||||
settings: &mut <Self::Settings as Settings>::FileContent,
|
||||
value: Self::Value,
|
||||
_cx: &AppContext,
|
||||
) {
|
||||
let value = if value { 1 } else { 0 };
|
||||
|
||||
let mut features = settings
|
||||
.buffer_font_features
|
||||
.as_ref()
|
||||
.map(|features| {
|
||||
features
|
||||
.tag_value_list()
|
||||
.into_iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(calt_index) = features.iter().position(|(tag, _)| tag == "calt") {
|
||||
features[calt_index].1 = value;
|
||||
} else {
|
||||
features.push(("calt".into(), value));
|
||||
}
|
||||
|
||||
settings.buffer_font_features = Some(FontFeatures(Arc::new(features)));
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for BufferFontLigaturesControl {
|
||||
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
|
||||
let value = Self::read(cx);
|
||||
|
||||
CheckboxWithLabel::new(
|
||||
"buffer-font-ligatures",
|
||||
Label::new(self.name()),
|
||||
value.into(),
|
||||
|selection, cx| {
|
||||
Self::write(
|
||||
match selection {
|
||||
Selection::Selected => true,
|
||||
Selection::Unselected | Selection::Indeterminate => false,
|
||||
},
|
||||
cx,
|
||||
);
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
struct InlineGitBlameControl;
|
||||
|
||||
@@ -242,3 +327,102 @@ impl RenderOnce for InlineGitBlameControl {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
struct LineNumbersControl;
|
||||
|
||||
impl EditableSettingControl for LineNumbersControl {
|
||||
type Value = bool;
|
||||
type Settings = EditorSettings;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Line Numbers".into()
|
||||
}
|
||||
|
||||
fn read(cx: &AppContext) -> Self::Value {
|
||||
let settings = EditorSettings::get_global(cx);
|
||||
settings.gutter.line_numbers
|
||||
}
|
||||
|
||||
fn apply(
|
||||
settings: &mut <Self::Settings as Settings>::FileContent,
|
||||
value: Self::Value,
|
||||
_cx: &AppContext,
|
||||
) {
|
||||
if let Some(gutter) = settings.gutter.as_mut() {
|
||||
gutter.line_numbers = Some(value);
|
||||
} else {
|
||||
settings.gutter = Some(crate::editor_settings::GutterContent {
|
||||
line_numbers: Some(value),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for LineNumbersControl {
|
||||
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
|
||||
let value = Self::read(cx);
|
||||
|
||||
CheckboxWithLabel::new(
|
||||
"line-numbers",
|
||||
Label::new(self.name()),
|
||||
value.into(),
|
||||
|selection, cx| {
|
||||
Self::write(
|
||||
match selection {
|
||||
Selection::Selected => true,
|
||||
Selection::Unselected | Selection::Indeterminate => false,
|
||||
},
|
||||
cx,
|
||||
);
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
struct RelativeLineNumbersControl;
|
||||
|
||||
impl EditableSettingControl for RelativeLineNumbersControl {
|
||||
type Value = bool;
|
||||
type Settings = EditorSettings;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Relative Line Numbers".into()
|
||||
}
|
||||
|
||||
fn read(cx: &AppContext) -> Self::Value {
|
||||
let settings = EditorSettings::get_global(cx);
|
||||
settings.relative_line_numbers
|
||||
}
|
||||
|
||||
fn apply(
|
||||
settings: &mut <Self::Settings as Settings>::FileContent,
|
||||
value: Self::Value,
|
||||
_cx: &AppContext,
|
||||
) {
|
||||
settings.relative_line_numbers = Some(value);
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for RelativeLineNumbersControl {
|
||||
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
|
||||
let value = Self::read(cx);
|
||||
|
||||
DropdownMenu::new(
|
||||
"relative-line-numbers",
|
||||
if value { "Relative" } else { "Ascending" },
|
||||
ContextMenu::build(cx, |menu, _cx| {
|
||||
menu.custom_entry(
|
||||
|_cx| Label::new("Ascending").into_any_element(),
|
||||
move |cx| Self::write(false, cx),
|
||||
)
|
||||
.custom_entry(
|
||||
|_cx| Label::new("Relative").into_any_element(),
|
||||
move |cx| Self::write(true, cx),
|
||||
)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10105,7 +10105,7 @@ struct Row8;
|
||||
struct Row9;
|
||||
struct Row10;"#};
|
||||
|
||||
// Deletion hunks trigger with carets on ajacent rows, so carets and selections have to stay farther to avoid the revert
|
||||
// Deletion hunks trigger with carets on adjacent rows, so carets and selections have to stay farther to avoid the revert
|
||||
assert_hunk_revert(
|
||||
indoc! {r#"struct Row;
|
||||
struct Row2;
|
||||
|
||||
@@ -59,6 +59,7 @@ use std::{
|
||||
fmt::{self, Write},
|
||||
iter, mem,
|
||||
ops::{Deref, Range},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
};
|
||||
use sum_tree::Bias;
|
||||
@@ -1969,6 +1970,7 @@ impl EditorElement {
|
||||
max_width: text_hitbox.size.width.max(*scroll_width),
|
||||
editor_style: &self.style,
|
||||
}))
|
||||
.cursor(CursorStyle::Arrow)
|
||||
.on_mouse_down(MouseButton::Left, |_, cx| cx.stop_propagation())
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -4105,11 +4107,11 @@ fn prepaint_gutter_button(
|
||||
);
|
||||
let indicator_size = button.layout_as_root(available_space, cx);
|
||||
|
||||
let blame_offset = gutter_dimensions.git_blame_entries_width;
|
||||
let gutter_offset = rows_with_hunk_bounds
|
||||
let blame_width = gutter_dimensions.git_blame_entries_width;
|
||||
let gutter_width = rows_with_hunk_bounds
|
||||
.get(&row)
|
||||
.map(|bounds| bounds.origin.x + bounds.size.width);
|
||||
let left_offset = blame_offset.max(gutter_offset).unwrap_or(Pixels::ZERO);
|
||||
.map(|bounds| bounds.size.width);
|
||||
let left_offset = blame_width.max(gutter_width).unwrap_or_default();
|
||||
|
||||
let mut x = left_offset;
|
||||
let available_width = gutter_dimensions.margin + gutter_dimensions.left_padding
|
||||
@@ -5492,7 +5494,7 @@ impl Element for EditorElement {
|
||||
|
||||
EditorLayout {
|
||||
mode: snapshot.mode,
|
||||
position_map: Arc::new(PositionMap {
|
||||
position_map: Rc::new(PositionMap {
|
||||
size: bounds.size,
|
||||
scroll_pixel_position,
|
||||
scroll_max,
|
||||
@@ -5642,7 +5644,7 @@ impl IntoElement for EditorElement {
|
||||
}
|
||||
|
||||
pub struct EditorLayout {
|
||||
position_map: Arc<PositionMap>,
|
||||
position_map: Rc<PositionMap>,
|
||||
hitbox: Hitbox,
|
||||
text_hitbox: Hitbox,
|
||||
gutter_hitbox: Hitbox,
|
||||
|
||||
@@ -13,8 +13,8 @@ use multi_buffer::{
|
||||
use settings::SettingsStore;
|
||||
use text::{BufferId, Point};
|
||||
use ui::{
|
||||
div, h_flex, v_flex, ActiveTheme, Context as _, ContextMenu, InteractiveElement, IntoElement,
|
||||
ParentElement, Pixels, Styled, ViewContext, VisualContext,
|
||||
div, h_flex, rems, v_flex, ActiveTheme, Context as _, ContextMenu, InteractiveElement,
|
||||
IntoElement, ParentElement, Pixels, Styled, ViewContext, VisualContext,
|
||||
};
|
||||
use util::{debug_panic, RangeExt};
|
||||
|
||||
@@ -484,7 +484,10 @@ impl Editor {
|
||||
.child(
|
||||
h_flex()
|
||||
.id("gutter hunk")
|
||||
.pl(hunk_bounds.origin.x)
|
||||
.pl(gutter_dimensions.margin
|
||||
+ gutter_dimensions
|
||||
.git_blame_entries_width
|
||||
.unwrap_or_default())
|
||||
.max_w(hunk_bounds.size.width)
|
||||
.min_w(hunk_bounds.size.width)
|
||||
.size_full()
|
||||
@@ -512,7 +515,7 @@ impl Editor {
|
||||
.child(
|
||||
v_flex()
|
||||
.size_full()
|
||||
.pt(ui::rems(0.25))
|
||||
.pt(rems(0.25))
|
||||
.justify_start()
|
||||
.child(close_button),
|
||||
),
|
||||
|
||||
@@ -44,7 +44,7 @@ impl SelectionsCollection {
|
||||
buffer,
|
||||
next_selection_id: 1,
|
||||
line_mode: false,
|
||||
disjoint: Arc::from([]),
|
||||
disjoint: Arc::default(),
|
||||
pending: Some(PendingSelection {
|
||||
selection: Selection {
|
||||
id: 0,
|
||||
@@ -398,7 +398,7 @@ impl<'a> MutableSelectionsCollection<'a> {
|
||||
}
|
||||
|
||||
pub fn clear_disjoint(&mut self) {
|
||||
self.collection.disjoint = Arc::from([]);
|
||||
self.collection.disjoint = Arc::default();
|
||||
}
|
||||
|
||||
pub fn delete(&mut self, selection_id: usize) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::Editor;
|
||||
use gpui::{Task as AsyncTask, WindowContext};
|
||||
use project::Location;
|
||||
use task::{TaskContext, TaskVariables, VariableName};
|
||||
use text::{Point, ToOffset, ToPoint};
|
||||
use text::{ToOffset, ToPoint};
|
||||
use workspace::Workspace;
|
||||
|
||||
fn task_context_with_editor(
|
||||
@@ -14,11 +14,7 @@ fn task_context_with_editor(
|
||||
return AsyncTask::ready(None);
|
||||
};
|
||||
let (selection, buffer, editor_snapshot) = {
|
||||
let mut selection = editor.selections.newest::<Point>(cx);
|
||||
if editor.selections.line_mode {
|
||||
selection.start = Point::new(selection.start.row, 0);
|
||||
selection.end = Point::new(selection.end.row + 1, 0);
|
||||
}
|
||||
let selection = editor.selections.newest_adjusted(cx);
|
||||
let Some((buffer, _, _)) = editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
|
||||
@@ -27,6 +27,7 @@ pub fn marked_display_snapshot(
|
||||
let font = Font {
|
||||
family: "Zed Plex Mono".into(),
|
||||
features: FontFeatures::default(),
|
||||
fallbacks: None,
|
||||
weight: FontWeight::default(),
|
||||
style: FontStyle::default(),
|
||||
};
|
||||
|
||||
@@ -327,7 +327,7 @@ impl EditorTestContext {
|
||||
.background_highlights
|
||||
.get(&TypeId::of::<Tag>())
|
||||
.map(|h| h.1.clone())
|
||||
.unwrap_or_else(|| Arc::from([]))
|
||||
.unwrap_or_else(|| Arc::default())
|
||||
.into_iter()
|
||||
.map(|range| range.to_offset(&snapshot.buffer_snapshot))
|
||||
.collect()
|
||||
|
||||
@@ -21,7 +21,6 @@ assistant_slash_command.workspace = true
|
||||
async-compression.workspace = true
|
||||
async-tar.workspace = true
|
||||
async-trait.workspace = true
|
||||
cap-std.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
|
||||
@@ -363,6 +363,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
},
|
||||
);
|
||||
|
||||
#[allow(clippy::let_underscore_future)]
|
||||
let _ = store.update(cx, |store, cx| store.reload(None, cx));
|
||||
|
||||
cx.executor().advance_clock(RELOAD_DEBOUNCE_DURATION);
|
||||
|
||||
@@ -159,29 +159,25 @@ impl WasmHost {
|
||||
}
|
||||
|
||||
async fn build_wasi_ctx(&self, manifest: &Arc<ExtensionManifest>) -> Result<wasi::WasiCtx> {
|
||||
use cap_std::{ambient_authority, fs::Dir};
|
||||
|
||||
let extension_work_dir = self.work_dir.join(manifest.id.as_ref());
|
||||
self.fs
|
||||
.create_dir(&extension_work_dir)
|
||||
.await
|
||||
.context("failed to create extension work dir")?;
|
||||
|
||||
let work_dir_preopen = Dir::open_ambient_dir(&extension_work_dir, ambient_authority())
|
||||
.context("failed to preopen extension work directory")?;
|
||||
let current_dir_preopen = work_dir_preopen
|
||||
.try_clone()
|
||||
.context("failed to preopen extension current directory")?;
|
||||
let extension_work_dir = extension_work_dir.to_string_lossy();
|
||||
|
||||
let perms = wasi::FilePerms::all();
|
||||
let file_perms = wasi::FilePerms::all();
|
||||
let dir_perms = wasi::DirPerms::all();
|
||||
|
||||
Ok(wasi::WasiCtxBuilder::new()
|
||||
.inherit_stdio()
|
||||
.preopened_dir(current_dir_preopen, dir_perms, perms, ".")
|
||||
.preopened_dir(work_dir_preopen, dir_perms, perms, &extension_work_dir)
|
||||
.env("PWD", &extension_work_dir)
|
||||
.preopened_dir(&extension_work_dir, ".", dir_perms, file_perms)?
|
||||
.preopened_dir(
|
||||
&extension_work_dir,
|
||||
&extension_work_dir.to_string_lossy(),
|
||||
dir_perms,
|
||||
file_perms,
|
||||
)?
|
||||
.env("PWD", &extension_work_dir.to_string_lossy())
|
||||
.env("RUST_BACKTRACE", "full")
|
||||
.build())
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ pub fn new_linker(
|
||||
f: impl Fn(&mut Linker<WasmState>, fn(&mut WasmState) -> &mut WasmState) -> Result<()>,
|
||||
) -> Linker<WasmState> {
|
||||
let mut linker = Linker::new(&wasm_engine());
|
||||
wasmtime_wasi::command::add_to_linker(&mut linker).unwrap();
|
||||
wasmtime_wasi::add_to_linker_async(&mut linker).unwrap();
|
||||
f(&mut linker, wasi_view).unwrap();
|
||||
linker
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ pub const MIN_VERSION: SemanticVersion = SemanticVersion::new(0, 0, 1);
|
||||
|
||||
wasmtime::component::bindgen!({
|
||||
async: true,
|
||||
trappable_imports: true,
|
||||
path: "../extension_api/wit/since_v0.0.1",
|
||||
with: {
|
||||
"worktree": ExtensionWorktree,
|
||||
|
||||
@@ -11,6 +11,7 @@ pub const MIN_VERSION: SemanticVersion = SemanticVersion::new(0, 0, 4);
|
||||
|
||||
wasmtime::component::bindgen!({
|
||||
async: true,
|
||||
trappable_imports: true,
|
||||
path: "../extension_api/wit/since_v0.0.4",
|
||||
with: {
|
||||
"worktree": ExtensionWorktree,
|
||||
|
||||
@@ -12,6 +12,7 @@ pub const MAX_VERSION: SemanticVersion = SemanticVersion::new(0, 0, 6);
|
||||
|
||||
wasmtime::component::bindgen!({
|
||||
async: true,
|
||||
trappable_imports: true,
|
||||
path: "../extension_api/wit/since_v0.0.6",
|
||||
with: {
|
||||
"worktree": ExtensionWorktree,
|
||||
|
||||
@@ -26,6 +26,7 @@ pub const MAX_VERSION: SemanticVersion = SemanticVersion::new(0, 0, 7);
|
||||
|
||||
wasmtime::component::bindgen!({
|
||||
async: true,
|
||||
trappable_imports: true,
|
||||
path: "../extension_api/wit/since_v0.0.7",
|
||||
with: {
|
||||
"worktree": ExtensionWorktree,
|
||||
|
||||
@@ -816,6 +816,7 @@ impl ExtensionsPage {
|
||||
},
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_fallbacks: settings.ui_font.fallbacks.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: settings.ui_font.weight,
|
||||
line_height: relative(1.3),
|
||||
|
||||
@@ -998,7 +998,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("b0.5")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1006,7 +1006,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("c1.0")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1014,7 +1014,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("a1.0")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1022,7 +1022,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("a0.5")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1030,7 +1030,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("b1.0")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
];
|
||||
@@ -1044,7 +1044,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("a1.0")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1052,7 +1052,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("b1.0")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1060,7 +1060,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("c1.0")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1068,7 +1068,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("a0.5")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
ProjectPanelOrdMatch(PathMatch {
|
||||
@@ -1076,7 +1076,7 @@ mod tests {
|
||||
positions: Vec::new(),
|
||||
worktree_id: 0,
|
||||
path: Arc::from(Path::new("b0.5")),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: 0,
|
||||
}),
|
||||
]
|
||||
|
||||
@@ -404,7 +404,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_match_multibyte_path_entries() {
|
||||
let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"];
|
||||
let paths = vec![
|
||||
"aαbβ/cγdδ",
|
||||
"αβγδ/bcde",
|
||||
"c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f",
|
||||
"/d/🆒/h",
|
||||
];
|
||||
assert_eq!("1️⃣".len(), 7);
|
||||
assert_eq!(
|
||||
match_single_path_query("bcd", false, &paths),
|
||||
|
||||
@@ -120,7 +120,7 @@ pub fn match_fixed_path_set(
|
||||
worktree_id,
|
||||
positions: Vec::new(),
|
||||
path: Arc::from(candidate.path),
|
||||
path_prefix: Arc::from(""),
|
||||
path_prefix: Arc::default(),
|
||||
distance_to_relative_ancestor: usize::MAX,
|
||||
},
|
||||
);
|
||||
|
||||
@@ -5,12 +5,20 @@ edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/google_ai.rs"
|
||||
|
||||
[features]
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
|
||||
@@ -1,23 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::HttpClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
client: Arc<dyn HttpClient>,
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
model: &str,
|
||||
request: GenerateContentRequest,
|
||||
mut request: GenerateContentRequest,
|
||||
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
|
||||
let uri = format!(
|
||||
"{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}",
|
||||
api_url, api_key
|
||||
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
|
||||
model = request.model
|
||||
);
|
||||
request.model.clear();
|
||||
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut response = client.post_json(&uri, request.into()).await?;
|
||||
@@ -52,8 +50,8 @@ pub async fn stream_generate_content(
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn count_tokens<T: HttpClient>(
|
||||
client: &T,
|
||||
pub async fn count_tokens(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: CountTokensRequest,
|
||||
@@ -91,22 +89,24 @@ pub enum Task {
|
||||
BatchEmbedContents,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentRequest {
|
||||
#[serde(default, skip_serializing_if = "String::is_empty")]
|
||||
pub model: String,
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentResponse {
|
||||
pub candidates: Option<Vec<GenerateContentCandidate>>,
|
||||
pub prompt_feedback: Option<PromptFeedback>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentCandidate {
|
||||
pub index: usize,
|
||||
@@ -157,7 +157,7 @@ pub struct GenerativeContentBlob {
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CitationSource {
|
||||
pub start_index: Option<usize>,
|
||||
@@ -166,13 +166,13 @@ pub struct CitationSource {
|
||||
pub license: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CitationMetadata {
|
||||
pub citation_sources: Vec<CitationSource>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptFeedback {
|
||||
pub block_reason: Option<String>,
|
||||
@@ -180,7 +180,7 @@ pub struct PromptFeedback {
|
||||
pub block_reason_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub candidate_count: Option<usize>,
|
||||
@@ -191,7 +191,7 @@ pub struct GenerationConfig {
|
||||
pub top_k: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetySetting {
|
||||
pub category: HarmCategory,
|
||||
@@ -224,7 +224,7 @@ pub enum HarmCategory {
|
||||
DangerousContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum HarmBlockThreshold {
|
||||
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
@@ -238,7 +238,7 @@ pub enum HarmBlockThreshold {
|
||||
BlockNone,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum HarmProbability {
|
||||
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
|
||||
@@ -249,21 +249,85 @@ pub enum HarmProbability {
|
||||
High,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetyRating {
|
||||
pub category: HarmCategory,
|
||||
pub probability: HarmProbability,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Vec<Content>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "gemini-1.5-pro")]
|
||||
Gemini15Pro,
|
||||
#[serde(rename = "gemini-1.5-flash")]
|
||||
Gemini15Flash,
|
||||
#[serde(rename = "custom")]
|
||||
Custom { name: String, max_tokens: usize },
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Model::Gemini15Pro => "gemini-1.5-pro",
|
||||
Model::Gemini15Flash => "gemini-1.5-flash",
|
||||
Model::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Model::Gemini15Pro => "Gemini 1.5 Pro",
|
||||
Model::Gemini15Flash => "Gemini 1.5 Flash",
|
||||
Model::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Model::Gemini15Pro => 2_000_000,
|
||||
Model::Gemini15Flash => 1_000_000,
|
||||
Model::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Model {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.id())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_text_from_events(
|
||||
events: impl Stream<Item = Result<GenerateContentResponse>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
events.filter_map(|event| async move {
|
||||
match event {
|
||||
Ok(event) => event.candidates.and_then(|candidates| {
|
||||
candidates.into_iter().next().and_then(|candidate| {
|
||||
candidate.content.parts.into_iter().next().and_then(|part| {
|
||||
if let Part::TextPart(TextPart { text }) = part {
|
||||
Some(Ok(text))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
}),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ thiserror.workspace = true
|
||||
time.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
waker-fn = "1.1.0"
|
||||
waker-fn = "1.2.0"
|
||||
|
||||
[dev-dependencies]
|
||||
backtrace = "0.3"
|
||||
@@ -93,6 +93,7 @@ cbindgen = { version = "0.26.0", default-features = false }
|
||||
block = "0.1"
|
||||
cocoa.workspace = true
|
||||
core-foundation.workspace = true
|
||||
core-foundation-sys = "0.8"
|
||||
core-graphics = "0.23"
|
||||
core-text = "20.1"
|
||||
foreign-types = "0.5"
|
||||
@@ -150,7 +151,7 @@ x11-clipboard = "0.9.2"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
windows.workspace = true
|
||||
windows-core = "0.57"
|
||||
windows-core = "0.58"
|
||||
|
||||
[[example]]
|
||||
name = "hello_world"
|
||||
|
||||
50
crates/gpui/examples/gif_viewer.rs
Normal file
50
crates/gpui/examples/gif_viewer.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use gpui::{
|
||||
div, img, prelude::*, App, AppContext, ImageSource, Render, ViewContext, WindowOptions,
|
||||
};
|
||||
use std::path::PathBuf;
|
||||
|
||||
struct GifViewer {
|
||||
gif_path: PathBuf,
|
||||
}
|
||||
|
||||
impl GifViewer {
|
||||
fn new(gif_path: PathBuf) -> Self {
|
||||
Self { gif_path }
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for GifViewer {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
div().size_full().child(
|
||||
img(ImageSource::File(self.gif_path.clone().into()))
|
||||
.size_full()
|
||||
.object_fit(gpui::ObjectFit::Contain)
|
||||
.id("gif"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
App::new().run(|cx: &mut AppContext| {
|
||||
let cwd = std::env::current_dir().expect("Failed to get current working directory");
|
||||
let gif_path = cwd.join("crates/gpui/examples/image/black-cat-typing.gif");
|
||||
|
||||
if !gif_path.exists() {
|
||||
eprintln!("Image file not found at {:?}", gif_path);
|
||||
eprintln!("Make sure you're running this example from the root of the gpui crate");
|
||||
cx.quit();
|
||||
return;
|
||||
}
|
||||
|
||||
cx.open_window(
|
||||
WindowOptions {
|
||||
focus: true,
|
||||
..Default::default()
|
||||
},
|
||||
|cx| cx.new_view(|_cx| GifViewer::new(gif_path)),
|
||||
)
|
||||
.unwrap();
|
||||
cx.activate(true);
|
||||
});
|
||||
}
|
||||
BIN
crates/gpui/examples/image/black-cat-typing.gif
Normal file
BIN
crates/gpui/examples/image/black-cat-typing.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.3 MiB |
@@ -1,6 +1,7 @@
|
||||
use crate::{size, DevicePixels, Result, SharedString, Size};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use image::RgbaImage;
|
||||
use image::{Delay, Frame};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
fmt,
|
||||
@@ -34,43 +35,54 @@ pub struct ImageId(usize);
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
pub(crate) struct RenderImageParams {
|
||||
pub(crate) image_id: ImageId,
|
||||
pub(crate) frame_index: usize,
|
||||
}
|
||||
|
||||
/// A cached and processed image.
|
||||
pub struct ImageData {
|
||||
/// The ID associated with this image
|
||||
pub id: ImageId,
|
||||
data: RgbaImage,
|
||||
data: SmallVec<[Frame; 1]>,
|
||||
}
|
||||
|
||||
impl ImageData {
|
||||
/// Create a new image from the given data.
|
||||
pub fn new(data: RgbaImage) -> Self {
|
||||
pub fn new(data: impl Into<SmallVec<[Frame; 1]>>) -> Self {
|
||||
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
Self {
|
||||
id: ImageId(NEXT_ID.fetch_add(1, SeqCst)),
|
||||
data,
|
||||
data: data.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this image into a byte slice.
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
&self.data
|
||||
pub fn as_bytes(&self, frame_index: usize) -> &[u8] {
|
||||
&self.data[frame_index].buffer()
|
||||
}
|
||||
|
||||
/// Get the size of this image, in pixels
|
||||
pub fn size(&self) -> Size<DevicePixels> {
|
||||
let (width, height) = self.data.dimensions();
|
||||
/// Get the size of this image, in pixels.
|
||||
pub fn size(&self, frame_index: usize) -> Size<DevicePixels> {
|
||||
let (width, height) = self.data[frame_index].buffer().dimensions();
|
||||
size(width.into(), height.into())
|
||||
}
|
||||
|
||||
/// Get the delay of this frame from the previous
|
||||
pub fn delay(&self, frame_index: usize) -> Delay {
|
||||
self.data[frame_index].delay()
|
||||
}
|
||||
|
||||
/// Get the number of frames for this image.
|
||||
pub fn frame_count(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ImageData {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("ImageData")
|
||||
.field("id", &self.id)
|
||||
.field("size", &self.data.dimensions())
|
||||
.field("size", &self.size(0))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,14 +323,14 @@ impl Interactivity {
|
||||
pub fn on_boxed_action(
|
||||
&mut self,
|
||||
action: &dyn Action,
|
||||
listener: impl Fn(&Box<dyn Action>, &mut WindowContext) + 'static,
|
||||
listener: impl Fn(&dyn Action, &mut WindowContext) + 'static,
|
||||
) {
|
||||
let action = action.boxed_clone();
|
||||
self.action_listeners.push((
|
||||
(*action).type_id(),
|
||||
Box::new(move |_, phase, cx| {
|
||||
if phase == DispatchPhase::Bubble {
|
||||
(listener)(&action, cx)
|
||||
(listener)(&*action, cx)
|
||||
}
|
||||
}),
|
||||
));
|
||||
@@ -757,7 +757,7 @@ pub trait InteractiveElement: Sized {
|
||||
fn on_boxed_action(
|
||||
mut self,
|
||||
action: &dyn Action,
|
||||
listener: impl Fn(&Box<dyn Action>, &mut WindowContext) + 'static,
|
||||
listener: impl Fn(&dyn Action, &mut WindowContext) + 'static,
|
||||
) -> Self {
|
||||
self.interactivity().on_boxed_action(action, listener);
|
||||
self
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
point, px, size, AbsoluteLength, Asset, Bounds, DefiniteLength, DevicePixels, Element,
|
||||
ElementId, GlobalElementId, Hitbox, ImageData, InteractiveElement, Interactivity, IntoElement,
|
||||
@@ -9,11 +5,20 @@ use crate::{
|
||||
WindowContext,
|
||||
};
|
||||
use futures::{AsyncReadExt, Future};
|
||||
use image::{ImageBuffer, ImageError};
|
||||
use http_client;
|
||||
use image::{
|
||||
codecs::gif::GifDecoder, AnimationDecoder, Frame, ImageBuffer, ImageError, ImageFormat,
|
||||
};
|
||||
#[cfg(target_os = "macos")]
|
||||
use media::core_video::CVImageBuffer;
|
||||
|
||||
use http_client;
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
fs,
|
||||
io::Cursor,
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use util::ResultExt;
|
||||
|
||||
@@ -230,8 +235,14 @@ impl Img {
|
||||
}
|
||||
}
|
||||
|
||||
/// The image state between frames
|
||||
struct ImgState {
|
||||
frame_index: usize,
|
||||
last_frame_time: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Element for Img {
|
||||
type RequestLayoutState = ();
|
||||
type RequestLayoutState = usize;
|
||||
type PrepaintState = Option<Hitbox>;
|
||||
|
||||
fn id(&self) -> Option<ElementId> {
|
||||
@@ -243,29 +254,65 @@ impl Element for Img {
|
||||
global_id: Option<&GlobalElementId>,
|
||||
cx: &mut WindowContext,
|
||||
) -> (LayoutId, Self::RequestLayoutState) {
|
||||
let layout_id = self
|
||||
.interactivity
|
||||
.request_layout(global_id, cx, |mut style, cx| {
|
||||
if let Some(data) = self.source.data(cx) {
|
||||
let image_size = data.size();
|
||||
match (style.size.width, style.size.height) {
|
||||
(Length::Auto, Length::Auto) => {
|
||||
style.size = Size {
|
||||
width: Length::Definite(DefiniteLength::Absolute(
|
||||
AbsoluteLength::Pixels(px(image_size.width.0 as f32)),
|
||||
)),
|
||||
height: Length::Definite(DefiniteLength::Absolute(
|
||||
AbsoluteLength::Pixels(px(image_size.height.0 as f32)),
|
||||
)),
|
||||
cx.with_optional_element_state(global_id, |state, cx| {
|
||||
let mut state = state.map(|state| {
|
||||
state.unwrap_or(ImgState {
|
||||
frame_index: 0,
|
||||
last_frame_time: None,
|
||||
})
|
||||
});
|
||||
|
||||
let frame_index = state.as_ref().map(|state| state.frame_index).unwrap_or(0);
|
||||
|
||||
let layout_id = self
|
||||
.interactivity
|
||||
.request_layout(global_id, cx, |mut style, cx| {
|
||||
if let Some(data) = self.source.data(cx) {
|
||||
if let Some(state) = &mut state {
|
||||
let frame_count = data.frame_count();
|
||||
if frame_count > 1 {
|
||||
let current_time = Instant::now();
|
||||
if let Some(last_frame_time) = state.last_frame_time {
|
||||
let elapsed = current_time - last_frame_time;
|
||||
let frame_duration =
|
||||
Duration::from(data.delay(state.frame_index));
|
||||
|
||||
if elapsed >= frame_duration {
|
||||
state.frame_index = (state.frame_index + 1) % frame_count;
|
||||
state.last_frame_time =
|
||||
Some(current_time - (elapsed - frame_duration));
|
||||
}
|
||||
} else {
|
||||
state.last_frame_time = Some(current_time);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
cx.request_layout(style, [])
|
||||
});
|
||||
(layout_id, ())
|
||||
let image_size = data.size(frame_index);
|
||||
match (style.size.width, style.size.height) {
|
||||
(Length::Auto, Length::Auto) => {
|
||||
style.size = Size {
|
||||
width: Length::Definite(DefiniteLength::Absolute(
|
||||
AbsoluteLength::Pixels(px(image_size.width.0 as f32)),
|
||||
)),
|
||||
height: Length::Definite(DefiniteLength::Absolute(
|
||||
AbsoluteLength::Pixels(px(image_size.height.0 as f32)),
|
||||
)),
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if global_id.is_some() && data.frame_count() > 1 {
|
||||
cx.request_animation_frame();
|
||||
}
|
||||
}
|
||||
|
||||
cx.request_layout(style, [])
|
||||
});
|
||||
|
||||
((layout_id, frame_index), state)
|
||||
})
|
||||
}
|
||||
|
||||
fn prepaint(
|
||||
@@ -283,7 +330,7 @@ impl Element for Img {
|
||||
&mut self,
|
||||
global_id: Option<&GlobalElementId>,
|
||||
bounds: Bounds<Pixels>,
|
||||
_: &mut Self::RequestLayoutState,
|
||||
frame_index: &mut Self::RequestLayoutState,
|
||||
hitbox: &mut Self::PrepaintState,
|
||||
cx: &mut WindowContext,
|
||||
) {
|
||||
@@ -293,9 +340,15 @@ impl Element for Img {
|
||||
let corner_radii = style.corner_radii.to_pixels(bounds.size, cx.rem_size());
|
||||
|
||||
if let Some(data) = source.data(cx) {
|
||||
let new_bounds = self.object_fit.get_bounds(bounds, data.size());
|
||||
cx.paint_image(new_bounds, corner_radii, data.clone(), self.grayscale)
|
||||
.log_err();
|
||||
let new_bounds = self.object_fit.get_bounds(bounds, data.size(*frame_index));
|
||||
cx.paint_image(
|
||||
new_bounds,
|
||||
corner_radii,
|
||||
data.clone(),
|
||||
*frame_index,
|
||||
self.grayscale,
|
||||
)
|
||||
.log_err();
|
||||
}
|
||||
|
||||
match source {
|
||||
@@ -385,12 +438,34 @@ impl Asset for Image {
|
||||
};
|
||||
|
||||
let data = if let Ok(format) = image::guess_format(&bytes) {
|
||||
let mut data = image::load_from_memory_with_format(&bytes, format)?.into_rgba8();
|
||||
let data = match format {
|
||||
ImageFormat::Gif => {
|
||||
let decoder = GifDecoder::new(Cursor::new(&bytes))?;
|
||||
let mut frames = SmallVec::new();
|
||||
|
||||
// Convert from RGBA to BGRA.
|
||||
for pixel in data.chunks_exact_mut(4) {
|
||||
pixel.swap(0, 2);
|
||||
}
|
||||
for frame in decoder.into_frames() {
|
||||
let mut frame = frame?;
|
||||
// Convert from RGBA to BGRA.
|
||||
for pixel in frame.buffer_mut().chunks_exact_mut(4) {
|
||||
pixel.swap(0, 2);
|
||||
}
|
||||
frames.push(frame);
|
||||
}
|
||||
|
||||
frames
|
||||
}
|
||||
_ => {
|
||||
let mut data =
|
||||
image::load_from_memory_with_format(&bytes, format)?.into_rgba8();
|
||||
|
||||
// Convert from RGBA to BGRA.
|
||||
for pixel in data.chunks_exact_mut(4) {
|
||||
pixel.swap(0, 2);
|
||||
}
|
||||
|
||||
SmallVec::from_elem(Frame::new(data), 1)
|
||||
}
|
||||
};
|
||||
|
||||
ImageData::new(data)
|
||||
} else {
|
||||
@@ -400,7 +475,7 @@ impl Asset for Image {
|
||||
let buffer =
|
||||
ImageBuffer::from_raw(pixmap.width(), pixmap.height(), pixmap.take()).unwrap();
|
||||
|
||||
ImageData::new(buffer)
|
||||
ImageData::new(SmallVec::from_elem(Frame::new(buffer), 1))
|
||||
};
|
||||
|
||||
Ok(Arc::new(data))
|
||||
|
||||
@@ -180,7 +180,7 @@ impl Transformation {
|
||||
}
|
||||
|
||||
fn into_matrix(self, center: Point<Pixels>, scale_factor: f32) -> TransformationMatrix {
|
||||
//Note: if you read this as a sequence of matrix mulitplications, start from the bottom
|
||||
//Note: if you read this as a sequence of matrix multiplications, start from the bottom
|
||||
TransformationMatrix::unit()
|
||||
.translate(center.scale(scale_factor) + self.translate.scale(scale_factor))
|
||||
.rotate(self.rotate)
|
||||
|
||||
@@ -325,7 +325,9 @@ impl UniformList {
|
||||
|
||||
let item_ix = cmp::min(self.item_to_measure_index, self.item_count - 1);
|
||||
let mut items = (self.render_items)(item_ix..item_ix + 1, cx);
|
||||
let mut item_to_measure = items.pop().unwrap();
|
||||
let Some(mut item_to_measure) = items.pop() else {
|
||||
return Size::default();
|
||||
};
|
||||
let available_space = size(
|
||||
list_width.map_or(AvailableSpace::MinContent, |width| {
|
||||
AvailableSpace::Definite(width)
|
||||
|
||||
@@ -940,6 +940,15 @@ where
|
||||
pub fn half_perimeter(&self) -> T {
|
||||
self.size.width.clone() + self.size.height.clone()
|
||||
}
|
||||
|
||||
/// centered_at creates a new bounds centered at the given point.
|
||||
pub fn centered_at(center: Point<T>, size: Size<T>) -> Self {
|
||||
let origin = Point {
|
||||
x: center.x - size.width.half(),
|
||||
y: center.y - size.height.half(),
|
||||
};
|
||||
Self::new(origin, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Default + Debug + PartialOrd + Add<T, Output = T> + Sub<Output = T>> Bounds<T> {
|
||||
|
||||
@@ -4,9 +4,6 @@
|
||||
mod app_menu;
|
||||
mod keystroke;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
mod cosmic_text;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
mod linux;
|
||||
|
||||
@@ -51,8 +48,6 @@ use uuid::Uuid;
|
||||
pub use app_menu::*;
|
||||
pub use keystroke::*;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
pub(crate) use cosmic_text::*;
|
||||
#[cfg(target_os = "linux")]
|
||||
pub(crate) use linux::*;
|
||||
#[cfg(target_os = "macos")]
|
||||
@@ -105,7 +100,6 @@ pub fn guess_compositor() -> &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
// todo("windows")
|
||||
#[cfg(target_os = "windows")]
|
||||
pub(crate) fn current_platform(_headless: bool) -> Rc<dyn Platform> {
|
||||
Rc::new(WindowsPlatform::new())
|
||||
@@ -261,7 +255,7 @@ pub enum Decorations {
|
||||
}
|
||||
|
||||
/// What window controls this platform supports
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Default)]
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct WindowControls {
|
||||
/// Whether this platform supports fullscreen
|
||||
pub fullscreen: bool,
|
||||
@@ -273,6 +267,18 @@ pub struct WindowControls {
|
||||
pub window_menu: bool,
|
||||
}
|
||||
|
||||
impl Default for WindowControls {
|
||||
fn default() -> Self {
|
||||
// Assume that we can do anything, unless told otherwise
|
||||
Self {
|
||||
fullscreen: true,
|
||||
maximize: true,
|
||||
minimize: true,
|
||||
window_menu: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A type to describe which sides of the window are currently tiled in some way
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Default)]
|
||||
pub struct Tiling {
|
||||
@@ -361,12 +367,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
|
||||
}
|
||||
fn set_app_id(&mut self, _app_id: &str) {}
|
||||
fn window_controls(&self) -> WindowControls {
|
||||
WindowControls {
|
||||
fullscreen: true,
|
||||
maximize: true,
|
||||
minimize: true,
|
||||
window_menu: false,
|
||||
}
|
||||
WindowControls::default()
|
||||
}
|
||||
fn set_client_inset(&self, _inset: Pixels) {}
|
||||
fn gpu_specs(&self) -> Option<GPUSpecs>;
|
||||
@@ -413,8 +414,6 @@ pub(crate) trait PlatformTextSystem: Send + Sync {
|
||||
raster_bounds: Bounds<DevicePixels>,
|
||||
) -> Result<(Size<DevicePixels>, Vec<u8>)>;
|
||||
fn layout_line(&self, text: &str, font_size: Pixels, runs: &[FontRun]) -> LineLayout;
|
||||
#[cfg(target_os = "windows")]
|
||||
fn destroy(&self);
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||
@@ -714,7 +713,6 @@ pub(crate) struct WindowParams {
|
||||
|
||||
pub display_id: Option<DisplayId>,
|
||||
|
||||
#[cfg_attr(target_os = "linux", allow(dead_code))]
|
||||
pub window_min_size: Option<Size<Pixels>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
mod text_system;
|
||||
|
||||
pub(crate) use text_system::*;
|
||||
@@ -237,14 +237,15 @@ pub struct Modifiers {
|
||||
}
|
||||
|
||||
impl Modifiers {
|
||||
/// Returns true if any modifier key is pressed
|
||||
/// Returns whether any modifier key is pressed.
|
||||
pub fn modified(&self) -> bool {
|
||||
self.control || self.alt || self.shift || self.platform || self.function
|
||||
}
|
||||
|
||||
/// Whether the semantically 'secondary' modifier key is pressed
|
||||
/// On macos, this is the command key
|
||||
/// On windows and linux, this is the control key
|
||||
/// Whether the semantically 'secondary' modifier key is pressed.
|
||||
///
|
||||
/// On macOS, this is the command key.
|
||||
/// On Linux and Windows, this is the control key.
|
||||
pub fn secondary(&self) -> bool {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
@@ -257,7 +258,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// How many modifier keys are pressed
|
||||
/// Returns how many modifier keys are pressed.
|
||||
pub fn number_of_modifiers(&self) -> u8 {
|
||||
self.control as u8
|
||||
+ self.alt as u8
|
||||
@@ -266,12 +267,12 @@ impl Modifiers {
|
||||
+ self.function as u8
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with no modifiers
|
||||
/// Returns [`Modifiers`] with no modifiers.
|
||||
pub fn none() -> Modifiers {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with just the command key
|
||||
/// Returns [`Modifiers`] with just the command key.
|
||||
pub fn command() -> Modifiers {
|
||||
Modifiers {
|
||||
platform: true,
|
||||
@@ -279,7 +280,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// A helper method for Modifiers with just the secondary key pressed
|
||||
/// A Returns [`Modifiers`] with just the secondary key pressed.
|
||||
pub fn secondary_key() -> Modifiers {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
@@ -298,7 +299,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with just the windows key
|
||||
/// Returns [`Modifiers`] with just the windows key.
|
||||
pub fn windows() -> Modifiers {
|
||||
Modifiers {
|
||||
platform: true,
|
||||
@@ -306,7 +307,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with just the super key
|
||||
/// Returns [`Modifiers`] with just the super key.
|
||||
pub fn super_key() -> Modifiers {
|
||||
Modifiers {
|
||||
platform: true,
|
||||
@@ -314,7 +315,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with just control
|
||||
/// Returns [`Modifiers`] with just control.
|
||||
pub fn control() -> Modifiers {
|
||||
Modifiers {
|
||||
control: true,
|
||||
@@ -322,7 +323,15 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with just shift
|
||||
/// Returns [`Modifiers`] with just control.
|
||||
pub fn alt() -> Modifiers {
|
||||
Modifiers {
|
||||
alt: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns [`Modifiers`] with just shift.
|
||||
pub fn shift() -> Modifiers {
|
||||
Modifiers {
|
||||
shift: true,
|
||||
@@ -330,7 +339,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with command + shift
|
||||
/// Returns [`Modifiers`] with command + shift.
|
||||
pub fn command_shift() -> Modifiers {
|
||||
Modifiers {
|
||||
shift: true,
|
||||
@@ -339,7 +348,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// helper method for Modifiers with command + shift
|
||||
/// Returns [`Modifiers`] with command + shift.
|
||||
pub fn control_shift() -> Modifiers {
|
||||
Modifiers {
|
||||
shift: true,
|
||||
@@ -348,7 +357,7 @@ impl Modifiers {
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if this Modifiers is a subset of another Modifiers
|
||||
/// Checks if this [`Modifiers`] is a subset of another [`Modifiers`].
|
||||
pub fn is_subset_of(&self, other: &Modifiers) -> bool {
|
||||
(other.control || !self.control)
|
||||
&& (other.alt || !self.alt)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod dispatcher;
|
||||
mod headless;
|
||||
mod platform;
|
||||
mod text_system;
|
||||
mod wayland;
|
||||
mod x11;
|
||||
mod xdg_desktop_portal;
|
||||
@@ -8,5 +9,6 @@ mod xdg_desktop_portal;
|
||||
pub(crate) use dispatcher::*;
|
||||
pub(crate) use headless::*;
|
||||
pub(crate) use platform::*;
|
||||
pub(crate) use text_system::*;
|
||||
pub(crate) use wayland::*;
|
||||
pub(crate) use x11::*;
|
||||
|
||||
@@ -64,13 +64,17 @@ impl PlatformTextSystem for CosmicTextSystem {
|
||||
}
|
||||
|
||||
fn all_font_names(&self) -> Vec<String> {
|
||||
self.0
|
||||
let mut result = self
|
||||
.0
|
||||
.read()
|
||||
.font_system
|
||||
.db()
|
||||
.faces()
|
||||
.map(|face| face.post_script_name.clone())
|
||||
.collect()
|
||||
.filter_map(|face| face.families.first().map(|family| family.0.clone()))
|
||||
.collect_vec();
|
||||
result.sort();
|
||||
result.dedup();
|
||||
result
|
||||
}
|
||||
|
||||
fn all_font_families(&self) -> Vec<String> {
|
||||
@@ -177,9 +181,6 @@ impl PlatformTextSystem for CosmicTextSystem {
|
||||
fn layout_line(&self, text: &str, font_size: Pixels, runs: &[FontRun]) -> LineLayout {
|
||||
self.0.write().layout_line(text, font_size, runs)
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn destroy(&self) {}
|
||||
}
|
||||
|
||||
impl CosmicTextSystemState {
|
||||
@@ -395,6 +395,7 @@ impl WaylandClient {
|
||||
let qh = event_queue.handle();
|
||||
|
||||
let mut seat: Option<wl_seat::WlSeat> = None;
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
let mut in_progress_outputs = HashMap::default();
|
||||
globals.contents().with_list(|list| {
|
||||
for global in list {
|
||||
@@ -874,6 +875,7 @@ impl Dispatch<wl_surface::WlSurface, ()> for WaylandClientStatePtr {
|
||||
let Some(window) = get_window(&mut state, &surface.id()) else {
|
||||
return;
|
||||
};
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
let outputs = state.outputs.clone();
|
||||
drop(state);
|
||||
|
||||
|
||||
@@ -185,13 +185,7 @@ impl WaylandWindowState {
|
||||
active: false,
|
||||
hovered: false,
|
||||
in_progress_window_controls: None,
|
||||
// Assume that we can do anything, unless told otherwise
|
||||
window_controls: WindowControls {
|
||||
fullscreen: true,
|
||||
maximize: true,
|
||||
minimize: true,
|
||||
window_menu: true,
|
||||
},
|
||||
window_controls: WindowControls::default(),
|
||||
inset: None,
|
||||
})
|
||||
}
|
||||
@@ -264,7 +258,10 @@ impl WaylandWindow {
|
||||
.wm_base
|
||||
.get_xdg_surface(&surface, &globals.qh, surface.id());
|
||||
let toplevel = xdg_surface.get_toplevel(&globals.qh, surface.id());
|
||||
toplevel.set_min_size(50, 50);
|
||||
|
||||
if let Some(size) = params.window_min_size {
|
||||
toplevel.set_min_size(size.width.0 as i32, size.height.0 as i32);
|
||||
}
|
||||
|
||||
if let Some(fractional_scale_manager) = globals.fractional_scale_manager.as_ref() {
|
||||
fractional_scale_manager.get_fractional_scale(&surface, &globals.qh, surface.id());
|
||||
@@ -545,6 +542,7 @@ impl WaylandWindowStatePtr {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::mutable_key_type)]
|
||||
pub fn handle_surface_event(
|
||||
&self,
|
||||
event: wl_surface::Event,
|
||||
|
||||
@@ -111,7 +111,9 @@ pub struct X11ClientState {
|
||||
|
||||
pub(crate) scale_factor: f32,
|
||||
|
||||
xkb_context: xkbc::Context,
|
||||
pub(crate) xcb_connection: Rc<XCBConnection>,
|
||||
xkb_device_id: i32,
|
||||
client_side_decorations_supported: bool,
|
||||
pub(crate) x_root_index: usize,
|
||||
pub(crate) _resource_database: Database,
|
||||
@@ -253,7 +255,9 @@ impl X11Client {
|
||||
.reply()
|
||||
.unwrap();
|
||||
|
||||
let events = xkb::EventType::STATE_NOTIFY;
|
||||
let events = xkb::EventType::STATE_NOTIFY
|
||||
| xkb::EventType::MAP_NOTIFY
|
||||
| xkb::EventType::NEW_KEYBOARD_NOTIFY;
|
||||
xcb_connection
|
||||
.xkb_select_events(
|
||||
xkb::ID::USE_CORE_KBD.into(),
|
||||
@@ -267,8 +271,8 @@ impl X11Client {
|
||||
assert!(xkb.supported);
|
||||
|
||||
let xkb_context = xkbc::Context::new(xkbc::CONTEXT_NO_FLAGS);
|
||||
let xkb_device_id = xkbc::x11::get_core_keyboard_device_id(&xcb_connection);
|
||||
let xkb_state = {
|
||||
let xkb_device_id = xkbc::x11::get_core_keyboard_device_id(&xcb_connection);
|
||||
let xkb_keymap = xkbc::x11::keymap_new_from_device(
|
||||
&xkb_context,
|
||||
&xcb_connection,
|
||||
@@ -349,7 +353,9 @@ impl X11Client {
|
||||
current_count: 0,
|
||||
scale_factor,
|
||||
|
||||
xkb_context,
|
||||
xcb_connection,
|
||||
xkb_device_id,
|
||||
client_side_decorations_supported,
|
||||
x_root_index,
|
||||
_resource_database: resource_database,
|
||||
@@ -621,6 +627,23 @@ impl X11Client {
|
||||
self.disable_ime();
|
||||
window.handle_ime_delete();
|
||||
}
|
||||
Event::XkbNewKeyboardNotify(_) | Event::MapNotify(_) => {
|
||||
let mut state = self.0.borrow_mut();
|
||||
let xkb_state = {
|
||||
let xkb_keymap = xkbc::x11::keymap_new_from_device(
|
||||
&state.xkb_context,
|
||||
&state.xcb_connection,
|
||||
state.xkb_device_id,
|
||||
xkbc::KEYMAP_COMPILE_NO_FLAGS,
|
||||
);
|
||||
xkbc::x11::state_new_from_device(
|
||||
&xkb_keymap,
|
||||
&state.xcb_connection,
|
||||
state.xkb_device_id,
|
||||
)
|
||||
};
|
||||
state.xkb = xkb_state;
|
||||
}
|
||||
Event::XkbStateNotify(event) => {
|
||||
let mut state = self.0.borrow_mut();
|
||||
state.xkb.update_mask(
|
||||
|
||||
@@ -14,6 +14,7 @@ use raw_window_handle as rwh;
|
||||
use util::{maybe, ResultExt};
|
||||
use x11rb::{
|
||||
connection::Connection,
|
||||
properties::WmSizeHints,
|
||||
protocol::{
|
||||
sync,
|
||||
xinput::{self, ConnectionExt as _},
|
||||
@@ -371,6 +372,14 @@ impl X11WindowState {
|
||||
visual.depth, x_window, visual_set.root, bounds.origin.x.0 + 2, bounds.origin.y.0, bounds.size.width.0, bounds.size.height.0)
|
||||
})?;
|
||||
|
||||
if let Some(size) = params.window_min_size {
|
||||
let mut size_hints = WmSizeHints::new();
|
||||
size_hints.min_size = Some((size.width.0 as i32, size.height.0 as i32));
|
||||
size_hints
|
||||
.set_normal_hints(xcb_connection, x_window)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let reply = xcb_connection
|
||||
.get_geometry(x_window)
|
||||
.unwrap()
|
||||
@@ -1221,6 +1230,14 @@ impl PlatformWindow for X11Window {
|
||||
|
||||
fn show_window_menu(&self, position: Point<Pixels>) {
|
||||
let state = self.0.state.borrow();
|
||||
|
||||
self.0
|
||||
.xcb_connection
|
||||
.ungrab_pointer(x11rb::CURRENT_TIME)
|
||||
.unwrap()
|
||||
.check()
|
||||
.unwrap();
|
||||
|
||||
let coords = self.get_root_position(position);
|
||||
let message = ClientMessageEvent::new(
|
||||
32,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
#![allow(unused, non_upper_case_globals)]
|
||||
|
||||
use crate::FontFeatures;
|
||||
use crate::{FontFallbacks, FontFeatures};
|
||||
use cocoa::appkit::CGFloat;
|
||||
use core_foundation::{
|
||||
array::{
|
||||
kCFTypeArrayCallBacks, CFArray, CFArrayAppendValue, CFArrayCreateMutable, CFMutableArrayRef,
|
||||
kCFTypeArrayCallBacks, CFArray, CFArrayAppendArray, CFArrayAppendValue,
|
||||
CFArrayCreateMutable, CFArrayGetCount, CFArrayGetValueAtIndex, CFArrayRef,
|
||||
CFMutableArrayRef,
|
||||
},
|
||||
base::{kCFAllocatorDefault, CFRelease, TCFType},
|
||||
dictionary::{
|
||||
@@ -13,21 +15,88 @@ use core_foundation::{
|
||||
number::CFNumber,
|
||||
string::{CFString, CFStringRef},
|
||||
};
|
||||
use core_foundation_sys::locale::CFLocaleCopyPreferredLanguages;
|
||||
use core_graphics::{display::CFDictionary, geometry::CGAffineTransform};
|
||||
use core_text::{
|
||||
font::{CTFont, CTFontRef},
|
||||
font::{cascade_list_for_languages, CTFont, CTFontRef},
|
||||
font_descriptor::{
|
||||
kCTFontFeatureSettingsAttribute, CTFontDescriptor, CTFontDescriptorCopyAttributes,
|
||||
CTFontDescriptorCreateCopyWithFeature, CTFontDescriptorCreateWithAttributes,
|
||||
kCTFontCascadeListAttribute, kCTFontFeatureSettingsAttribute, CTFontDescriptor,
|
||||
CTFontDescriptorCopyAttributes, CTFontDescriptorCreateCopyWithFeature,
|
||||
CTFontDescriptorCreateWithAttributes, CTFontDescriptorCreateWithNameAndSize,
|
||||
CTFontDescriptorRef,
|
||||
},
|
||||
};
|
||||
use font_kit::font::Font;
|
||||
use font_kit::font::Font as FontKitFont;
|
||||
use std::ptr;
|
||||
|
||||
pub fn apply_features(font: &mut Font, features: &FontFeatures) {
|
||||
pub fn apply_features_and_fallbacks(
|
||||
font: &mut FontKitFont,
|
||||
features: &FontFeatures,
|
||||
fallbacks: Option<&FontFallbacks>,
|
||||
) -> anyhow::Result<()> {
|
||||
unsafe {
|
||||
let fallback_array = CFArrayCreateMutable(kCFAllocatorDefault, 0, &kCFTypeArrayCallBacks);
|
||||
|
||||
if let Some(fallbacks) = fallbacks {
|
||||
for user_fallback in fallbacks.fallback_list() {
|
||||
let name = CFString::from(user_fallback.as_str());
|
||||
let fallback_desc =
|
||||
CTFontDescriptorCreateWithNameAndSize(name.as_concrete_TypeRef(), 0.0);
|
||||
CFArrayAppendValue(fallback_array, fallback_desc as _);
|
||||
CFRelease(fallback_desc as _);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let preferred_languages: CFArray<CFString> =
|
||||
CFArray::wrap_under_create_rule(CFLocaleCopyPreferredLanguages());
|
||||
|
||||
let default_fallbacks = CTFontCopyDefaultCascadeListForLanguages(
|
||||
font.native_font().as_concrete_TypeRef(),
|
||||
preferred_languages.as_concrete_TypeRef(),
|
||||
);
|
||||
let default_fallbacks: CFArray<CTFontDescriptor> =
|
||||
CFArray::wrap_under_create_rule(default_fallbacks);
|
||||
|
||||
default_fallbacks
|
||||
.iter()
|
||||
.filter(|desc| desc.font_path().is_some())
|
||||
.map(|desc| {
|
||||
CFArrayAppendValue(fallback_array, desc.as_concrete_TypeRef() as _);
|
||||
});
|
||||
}
|
||||
|
||||
let feature_array = generate_feature_array(features);
|
||||
let keys = [kCTFontFeatureSettingsAttribute, kCTFontCascadeListAttribute];
|
||||
let values = [feature_array, fallback_array];
|
||||
let attrs = CFDictionaryCreate(
|
||||
kCFAllocatorDefault,
|
||||
keys.as_ptr() as _,
|
||||
values.as_ptr() as _,
|
||||
2,
|
||||
&kCFTypeDictionaryKeyCallBacks,
|
||||
&kCFTypeDictionaryValueCallBacks,
|
||||
);
|
||||
CFRelease(feature_array as *const _ as _);
|
||||
CFRelease(fallback_array as *const _ as _);
|
||||
let new_descriptor = CTFontDescriptorCreateWithAttributes(attrs);
|
||||
CFRelease(attrs as _);
|
||||
let new_descriptor = CTFontDescriptor::wrap_under_create_rule(new_descriptor);
|
||||
let new_font = CTFontCreateCopyWithAttributes(
|
||||
font.native_font().as_concrete_TypeRef(),
|
||||
0.0,
|
||||
std::ptr::null(),
|
||||
new_descriptor.as_concrete_TypeRef(),
|
||||
);
|
||||
let new_font = CTFont::wrap_under_create_rule(new_font);
|
||||
*font = font_kit::font::Font::from_native_font(&new_font);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_feature_array(features: &FontFeatures) -> CFMutableArrayRef {
|
||||
unsafe {
|
||||
let native_font = font.native_font();
|
||||
let mut feature_array =
|
||||
CFArrayCreateMutable(kCFAllocatorDefault, 0, &kCFTypeArrayCallBacks);
|
||||
for (tag, value) in features.tag_value_list() {
|
||||
@@ -48,26 +117,7 @@ pub fn apply_features(font: &mut Font, features: &FontFeatures) {
|
||||
CFArrayAppendValue(feature_array, dict as _);
|
||||
CFRelease(dict as _);
|
||||
}
|
||||
let attrs = CFDictionaryCreate(
|
||||
kCFAllocatorDefault,
|
||||
&kCTFontFeatureSettingsAttribute as *const _ as _,
|
||||
&feature_array as *const _ as _,
|
||||
1,
|
||||
&kCFTypeDictionaryKeyCallBacks,
|
||||
&kCFTypeDictionaryValueCallBacks,
|
||||
);
|
||||
CFRelease(feature_array as *const _ as _);
|
||||
let new_descriptor = CTFontDescriptorCreateWithAttributes(attrs);
|
||||
CFRelease(attrs as _);
|
||||
let new_descriptor = CTFontDescriptor::wrap_under_create_rule(new_descriptor);
|
||||
let new_font = CTFontCreateCopyWithAttributes(
|
||||
font.native_font().as_concrete_TypeRef(),
|
||||
0.0,
|
||||
ptr::null(),
|
||||
new_descriptor.as_concrete_TypeRef(),
|
||||
);
|
||||
let new_font = CTFont::wrap_under_create_rule(new_font);
|
||||
*font = Font::from_native_font(&new_font);
|
||||
feature_array
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,4 +132,8 @@ extern "C" {
|
||||
matrix: *const CGAffineTransform,
|
||||
attributes: CTFontDescriptorRef,
|
||||
) -> CTFontRef;
|
||||
fn CTFontCopyDefaultCascadeListForLanguages(
|
||||
font: CTFontRef,
|
||||
languagePrefList: CFArrayRef,
|
||||
) -> CFArrayRef;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
point, px, size, Bounds, DevicePixels, Font, FontFeatures, FontId, FontMetrics, FontRun,
|
||||
FontStyle, FontWeight, GlyphId, LineLayout, Pixels, PlatformTextSystem, Point,
|
||||
point, px, size, Bounds, DevicePixels, Font, FontFallbacks, FontFeatures, FontId, FontMetrics,
|
||||
FontRun, FontStyle, FontWeight, GlyphId, LineLayout, Pixels, PlatformTextSystem, Point,
|
||||
RenderGlyphParams, Result, ShapedGlyph, ShapedRun, SharedString, Size, SUBPIXEL_VARIANTS,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
@@ -43,7 +43,7 @@ use pathfinder_geometry::{
|
||||
use smallvec::SmallVec;
|
||||
use std::{borrow::Cow, char, cmp, convert::TryFrom, sync::Arc};
|
||||
|
||||
use super::open_type;
|
||||
use super::open_type::apply_features_and_fallbacks;
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
const kCGImageAlphaOnly: u32 = 7;
|
||||
@@ -54,6 +54,7 @@ pub(crate) struct MacTextSystem(RwLock<MacTextSystemState>);
|
||||
struct FontKey {
|
||||
font_family: SharedString,
|
||||
font_features: FontFeatures,
|
||||
font_fallbacks: Option<FontFallbacks>,
|
||||
}
|
||||
|
||||
struct MacTextSystemState {
|
||||
@@ -123,11 +124,13 @@ impl PlatformTextSystem for MacTextSystem {
|
||||
let font_key = FontKey {
|
||||
font_family: font.family.clone(),
|
||||
font_features: font.features.clone(),
|
||||
font_fallbacks: font.fallbacks.clone(),
|
||||
};
|
||||
let candidates = if let Some(font_ids) = lock.font_ids_by_font_key.get(&font_key) {
|
||||
font_ids.as_slice()
|
||||
} else {
|
||||
let font_ids = lock.load_family(&font.family, &font.features)?;
|
||||
let font_ids =
|
||||
lock.load_family(&font.family, &font.features, font.fallbacks.as_ref())?;
|
||||
lock.font_ids_by_font_key.insert(font_key.clone(), font_ids);
|
||||
lock.font_ids_by_font_key[&font_key].as_ref()
|
||||
};
|
||||
@@ -212,6 +215,7 @@ impl MacTextSystemState {
|
||||
&mut self,
|
||||
name: &str,
|
||||
features: &FontFeatures,
|
||||
fallbacks: Option<&FontFallbacks>,
|
||||
) -> Result<SmallVec<[FontId; 4]>> {
|
||||
let name = if name == ".SystemUIFont" {
|
||||
".AppleSystemUIFont"
|
||||
@@ -227,8 +231,7 @@ impl MacTextSystemState {
|
||||
for font in family.fonts() {
|
||||
let mut font = font.load()?;
|
||||
|
||||
open_type::apply_features(&mut font, features);
|
||||
|
||||
apply_features_and_fallbacks(&mut font, features, fallbacks)?;
|
||||
// This block contains a precautionary fix to guard against loading fonts
|
||||
// that might cause panics due to `.unwrap()`s up the chain.
|
||||
{
|
||||
@@ -457,6 +460,7 @@ impl MacTextSystemState {
|
||||
CFRange::init(utf16_start as isize, (utf16_end - utf16_start) as isize);
|
||||
|
||||
let font: &FontKitFont = &self.fonts[run.font_id.0];
|
||||
|
||||
unsafe {
|
||||
string.set_attribute(
|
||||
cf_range,
|
||||
@@ -634,7 +638,7 @@ impl From<FontStyle> for FontkitStyle {
|
||||
}
|
||||
}
|
||||
|
||||
// Some fonts may have no attributest despite `core_text` requiring them (and panicking).
|
||||
// Some fonts may have no attributes despite `core_text` requiring them (and panicking).
|
||||
// This is the same version as `core_text` has without `expect` calls.
|
||||
mod lenient_font_attributes {
|
||||
use core_foundation::{
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user