Co-Authored-By: Antonio Scandurra <antonio@zed.dev>
Co-Authored-By: Kyle Kelley <kylek@zed.dev>
This commit is contained in:
Nathan Sobo
2024-04-11 09:15:32 -06:00
parent b857beb2c6
commit 385da79021
9 changed files with 218 additions and 20 deletions

5
Cargo.lock generated
View File

@@ -381,10 +381,15 @@ dependencies = [
name = "assistant2"
version = "0.1.0"
dependencies = [
"anyhow",
"assets",
"client",
"editor",
"env_logger",
"futures 0.3.28",
"gpui",
"language",
"release_channel",
"semantic_index",
"settings",
"theme",

View File

@@ -209,7 +209,13 @@
}
},
{
"context": "AssistantPanel",
"context": "AssistantChat > Editor", // Used in the assistant2 crate
"bindings": {
"enter": "assistant::Submit"
}
},
{
"context": "AssistantPanel", // Used in the assistant crate, which we're replacing
"bindings": {
"cmd-g": "search::SelectNextMatch",
"cmd-shift-g": "search::SelectPrevMatch"

View File

@@ -9,7 +9,10 @@ path = "src/assistant2.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow.workspace = true
client.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
semantic_index.workspace = true
@@ -21,8 +24,10 @@ workspace.workspace = true
[dev-dependencies]
assets.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
release_channel.workspace = true
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }

View File

@@ -1,11 +1,13 @@
use assets::Assets;
use assistant2::AssistantPanel;
use client::Client;
use gpui::{App, View, WindowOptions};
use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
use theme::LoadThemes;
use ui::{div, prelude::*, Render};
fn main() {
env_logger::init();
App::new().with_assets(Assets).run(|cx| {
settings::init(cx);
language::init(cx);
@@ -13,6 +15,16 @@ fn main() {
theme::init(LoadThemes::JustBase, cx);
Assets.load_fonts(cx).unwrap();
KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
client::init_settings(cx);
release_channel::init("0.0.0", cx);
let client = Client::production(cx);
{
let client = client.clone();
cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
.detach_and_log_err(cx);
}
assistant2::init(client, cx);
cx.open_window(WindowOptions::default(), |cx| {
cx.new_view(|cx| Example::new(cx))

View File

@@ -1,11 +1,29 @@
mod completion_provider;
use std::sync::Arc;
use client::Client;
use completion_provider::*;
use editor::Editor;
use gpui::{list, AnyElement, ListAlignment, ListState, Render, View};
use futures::StreamExt;
use gpui::{
list, prelude::IntoElement, AnyElement, AppContext, Global, ListAlignment, ListState, Render,
View,
};
use language::language_settings::SoftWrap;
use semantic_index::SearchResult;
use settings::Settings;
use theme::ThemeSettings;
use ui::prelude::*;
gpui::actions!(assistant, [Submit]);
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(
client,
)));
}
pub struct AssistantPanel {
chat: View<AssistantChat>,
}
@@ -62,9 +80,37 @@ impl AssistantChat {
}
}
fn submit(&mut self, _: &Submit, cx: &mut ViewContext<Self>) {
// Detect which message is focused and send the ones above it
//
let completion = CompletionProvider::get(cx).complete(
"openai/gpt-4-turbo-preview".to_string(),
self.messages(cx),
Vec::new(),
1.0,
);
cx.spawn(|this, cx| async move {
dbg!();
let mut stream = completion.await?;
dbg!();
while let Some(chunk) = stream.next().await {
dbg!();
let text = chunk?;
dbg!(text);
}
dbg!();
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
match &self.messages[ix] {
AssistantMessage::User { body, contexts } => div()
.on_action(cx.listener(Self::submit))
.p_2()
.text_color(cx.theme().colors().editor_foreground)
.font(ThemeSettings::get_global(cx).buffer_font.clone())
@@ -74,14 +120,31 @@ impl AssistantChat {
AssistantMessage::Assistant { body } => body.clone().into_any_element(),
}
}
fn messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
self.messages
.iter()
.map(|message| match message {
AssistantMessage::User { body, contexts } => CompletionMessage {
role: CompletionRole::User,
body: body.read(cx).text(cx),
},
AssistantMessage::Assistant { body } => CompletionMessage {
role: CompletionRole::Assistant,
body: body.to_string(),
},
})
.collect()
}
}
impl Render for AssistantChat {
fn render(
&mut self,
cx: &mut workspace::ui::prelude::ViewContext<Self>,
) -> impl gpui::prelude::IntoElement {
list(self.list_state.clone()).size_full()
fn render(&mut self, cx: &mut workspace::ui::prelude::ViewContext<Self>) -> impl IntoElement {
div()
.flex_1()
.v_flex()
.key_context("AssistantChat")
.child(list(self.list_state.clone()).flex_1())
}
}

View File

@@ -0,0 +1,105 @@
use anyhow::Result;
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::Global;
use std::sync::Arc;
pub enum CompletionRole {
User,
Assistant,
System,
}
pub struct CompletionMessage {
pub role: CompletionRole,
pub body: String,
}
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);
impl CompletionProvider {
pub fn new(backend: impl CompletionProviderBackend) -> Self {
Self(Arc::new(backend))
}
pub fn complete(
&self,
model: String,
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
self.0.complete(model, messages, stop, temperature)
}
}
impl Global for CompletionProvider {}
pub trait CompletionProviderBackend: 'static {
fn complete(
&self,
model: String,
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
pub struct CloudCompletionProvider {
client: Arc<Client>,
}
impl CloudCompletionProvider {
pub fn new(client: Arc<Client>) -> Self {
Self { client }
}
}
impl CompletionProviderBackend for CloudCompletionProvider {
fn complete(
&self,
model: String,
messages: Vec<CompletionMessage>,
stop: Vec<String>,
temperature: f32,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let client = self.client.clone();
async move {
let stream = client
.request_stream(proto::CompleteWithLanguageModel {
model,
messages: messages
.into_iter()
.map(|message| proto::LanguageModelRequestMessage {
role: match message.role {
CompletionRole::User => {
proto::LanguageModelRole::LanguageModelUser as i32
}
CompletionRole::Assistant => {
proto::LanguageModelRole::LanguageModelAssistant as i32
}
CompletionRole::System => {
proto::LanguageModelRole::LanguageModelSystem as i32
}
},
content: message.body,
})
.collect(),
stop,
temperature,
})
.await?;
Ok(stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed())
}
.boxed()
}
}

View File

@@ -457,6 +457,14 @@ impl Client {
})
}
pub fn production(cx: &mut AppContext) -> Arc<Self> {
let clock = Arc::new(clock::RealSystemClock);
let http = Arc::new(HttpClientWithUrl::new(
&ClientSettings::get_global(cx).server_url,
));
Self::new(clock, http.clone(), cx)
}
pub fn id(&self) -> u64 {
self.id.load(Ordering::SeqCst)
}

View File

@@ -10141,7 +10141,7 @@ impl Render for Editor {
let settings = ThemeSettings::get_global(cx);
let text_style = match self.mode {
EditorMode::SingleLine | EditorMode::AutoHeight { .. } => dbg!(cx.text_style()),
EditorMode::SingleLine | EditorMode::AutoHeight { .. } => cx.text_style(),
EditorMode::Full => TextStyle {
color: cx.theme().colors().editor_foreground,
font_family: settings.buffer_font.family.clone(),

View File

@@ -169,20 +169,14 @@ fn main() {
settings::init(cx);
handle_settings_file_changes(user_settings_file_rx, cx);
handle_keymap_file_changes(user_keymap_file_rx, cx);
client::init_settings(cx);
let clock = Arc::new(clock::RealSystemClock);
let http = Arc::new(HttpClientWithUrl::new(
&client::ClientSettings::get_global(cx).server_url,
));
let client = client::Client::new(clock, http.clone(), cx);
let client = Client::production(cx);
let mut languages =
LanguageRegistry::new(login_shell_env_loaded, cx.background_executor().clone());
let copilot_language_server_id = languages.next_language_server_id();
languages.set_language_server_download_dir(paths::LANGUAGES_DIR.clone());
let languages = Arc::new(languages);
let node_runtime = RealNodeRuntime::new(http.clone());
let node_runtime = RealNodeRuntime::new(client.http_client());
language::init(cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
@@ -202,7 +196,7 @@ fn main() {
diagnostics::init(cx);
copilot::init(
copilot_language_server_id,
http.clone(),
client.http_client(),
node_runtime.clone(),
cx,
);
@@ -227,7 +221,7 @@ fn main() {
cx.observe_global::<SettingsStore>({
let languages = languages.clone();
let http = http.clone();
let http = client.http_client();
let client = client.clone();
move |cx| {
@@ -276,7 +270,7 @@ fn main() {
AppState::set_global(Arc::downgrade(&app_state), cx);
audio::init(Assets, cx);
auto_update::init(http.clone(), cx);
auto_update::init(client.http_client(), cx);
workspace::init(app_state.clone(), cx);
recent_projects::init(cx);
@@ -309,7 +303,7 @@ fn main() {
initialize_workspace(app_state.clone(), cx);
// todo(linux): unblock this
upload_panics_and_crashes(http.clone(), cx);
upload_panics_and_crashes(client.http_client(), cx);
cx.activate(true);