From 405f7cf64fa4a0edbfef314cb8f617f930fde70d Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Wed, 2 Jul 2025 16:49:37 -0600 Subject: [PATCH] Hack in authentication Co-authored-by: Mikayla Maki --- crates/acp/src/server.rs | 20 +++++- crates/acp/src/thread_view.rs | 122 ++++++++++++++++++++-------------- 2 files changed, 89 insertions(+), 53 deletions(-) diff --git a/crates/acp/src/server.rs b/crates/acp/src/server.rs index 983b329041..6881e347cf 100644 --- a/crates/acp/src/server.rs +++ b/crates/acp/src/server.rs @@ -238,13 +238,13 @@ impl acp::Client for AcpClientDelegate { } impl AcpServer { - pub fn stdio(mut process: Child, project: Entity, cx: &mut AsyncApp) -> Arc { + pub fn stdio(mut process: Child, project: Entity, cx: &mut App) -> Arc { let stdin = process.stdin.take().expect("process didn't have stdin"); let stdout = process.stdout.take().expect("process didn't have stdout"); let threads: Arc>>> = Default::default(); let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent( - AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()), + AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()), stdin, stdout, ); @@ -269,6 +269,22 @@ impl AcpServer { }) } + pub async fn initialize(&self) -> Result { + self.connection + .request(acp::InitializeParams) + .await + .map_err(to_anyhow) + } + + pub async fn authenticate(&self) -> Result<()> { + self.connection + .request(acp::AuthenticateParams) + .await + .map_err(to_anyhow)?; + + Ok(()) + } + pub async fn create_thread(self: Arc, cx: &mut AsyncApp) -> Result> { let response = self .connection diff --git a/crates/acp/src/thread_view.rs b/crates/acp/src/thread_view.rs index ed4bafb46c..96bae43c2e 100644 --- a/crates/acp/src/thread_view.rs +++ b/crates/acp/src/thread_view.rs @@ -1,5 +1,6 @@ use std::path::Path; use std::rc::Rc; +use std::sync::Arc; use std::time::Duration; use agentic_coding_protocol::{self as acp, ToolCallConfirmation}; @@ -12,20 +13,14 @@ use gpui::{ }; use gpui::{FocusHandle, Task}; use language::Buffer; -<<<<<<< HEAD use language::language_settings::SoftWrap; -use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle}; -||||||| parent of 47b80cc740 (Show errors from ACP when requests error) -use markdown::{HeadingLevelStyles, MarkdownElement, MarkdownStyle}; -======= use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; ->>>>>>> 47b80cc740 (Show errors from ACP when requests error) use project::Project; use settings::Settings as _; use theme::ThemeSettings; use ui::prelude::*; use ui::{Button, Tooltip}; -use util::ResultExt; +use util::{ResultExt, paths}; use zed_actions::agent::Chat; use crate::{ @@ -34,6 +29,7 @@ use crate::{ }; pub struct AcpThreadView { + agent: Arc, thread_state: ThreadState, // todo! reconsider structure. currently pretty sparse, but easy to clean up if we need to delete entries. thread_entry_views: Vec>, @@ -41,6 +37,7 @@ pub struct AcpThreadView { last_error: Option>, list_state: ListState, send_task: Option>>, + auth_task: Option>, } #[derive(Debug)] @@ -57,6 +54,7 @@ enum ThreadState { _subscription: Subscription, }, LoadError(SharedString), + Unauthenticated, } impl AcpThreadView { @@ -107,21 +105,12 @@ impl AcpThreadView { } } - fn initial_state( - project: Entity, - window: &mut Window, - cx: &mut Context, - ) -> ThreadState { - let Some(root_dir) = project + let root_dir = project .read(cx) .visible_worktrees(cx) .next() .map(|worktree| worktree.read(cx).abs_path()) - else { - return ThreadState::LoadError( - "Gemini threads must be created within a project".into(), - ); - }; + .unwrap_or_else(|| paths::home_dir().as_path().into()); let cli_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli"); @@ -137,45 +126,44 @@ impl AcpThreadView { .spawn() .unwrap(); - let project = project.clone(); + let agent = AcpServer::stdio(child, project, cx); + + Self { + thread_state: Self::initial_state(agent.clone(), window, cx), + agent, + message_editor, + send_task: None, + list_state: list_state, + last_error: None, + auth_task: None, + } + } + + fn initial_state( + agent: Arc, + window: &mut Window, + cx: &mut Context, + ) -> ThreadState { let load_task = cx.spawn_in(window, async move |this, cx| { - let agent = AcpServer::stdio(child, project, cx); - let result = agent.clone().create_thread(cx).await; + let result = match agent.initialize().await { + Err(e) => Err(e), + Ok(response) => { + if !response.is_authenticated { + this.update(cx, |this, _| { + this.thread_state = ThreadState::Unauthenticated; + }) + .ok(); + return; + } + agent.clone().create_thread(cx).await + } + }; this.update_in(cx, |this, window, cx| { match result { Ok(thread) => { -<<<<<<< HEAD let subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event); -||||||| parent of 47b80cc740 (Show errors from ACP when requests error) - let subscription = cx.subscribe(&thread, |this, _, event, cx| { - let count = this.list_state.item_count(); - match event { - AcpThreadEvent::NewEntry => { - this.list_state.splice(count..count, 1); - } - AcpThreadEvent::EntryUpdated(index) => { - this.list_state.splice(*index..*index + 1, 1); - } - } - cx.notify(); - }); -======= - dbg!(&thread); - let subscription = cx.subscribe(&thread, |this, _, event, cx| { - let count = this.list_state.item_count(); - match event { - AcpThreadEvent::NewEntry => { - this.list_state.splice(count..count, 1); - } - AcpThreadEvent::EntryUpdated(index) => { - this.list_state.splice(*index..*index + 1, 1); - } - } - cx.notify(); - }); ->>>>>>> 47b80cc740 (Show errors from ACP when requests error) this.list_state .splice(0..0, thread.read(cx).entries().len()); @@ -210,7 +198,9 @@ impl AcpThreadView { fn thread(&self) -> Option<&Entity> { match &self.thread_state { ThreadState::Ready { thread, .. } => Some(thread), - ThreadState::Loading { .. } | ThreadState::LoadError(..) => None, + ThreadState::Loading { .. } + | ThreadState::LoadError(..) + | ThreadState::Unauthenticated => None, } } @@ -219,6 +209,7 @@ impl AcpThreadView { ThreadState::Ready { thread, .. } => thread.read(cx).title(), ThreadState::Loading { .. } => "Loading...".into(), ThreadState::LoadError(_) => "Failed to load".into(), + ThreadState::Unauthenticated => "Not authenticated".into(), } } @@ -357,6 +348,27 @@ impl AcpThreadView { } } + fn authenticate(&mut self, window: &mut Window, cx: &mut Context) { + let agent = self.agent.clone(); + + self.auth_task = Some(cx.spawn_in(window, async move |this, cx| { + let result = agent.authenticate().await; + + this.update_in(cx, |this, window, cx| { + if let Err(err) = result { + this.last_error = + Some(cx.new(|cx| { + Markdown::new(format!("Error: {err}").into(), None, None, cx) + })) + } else { + this.thread_state = Self::initial_state(agent, window, cx) + } + this.auth_task.take() + }) + .ok(); + })); + } + fn authorize_tool_call( &mut self, id: ToolCallId, @@ -920,6 +932,14 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::chat)) .h_full() .child(match &self.thread_state { + ThreadState::Unauthenticated => v_flex() + .p_2() + .flex_1() + .justify_end() + .child(Label::new("Not authenticated")) + .child(Button::new("sign-in", "Sign in via Gemini CLI").on_click( + cx.listener(|this, _, window, cx| this.authenticate(window, cx)), + )), ThreadState::Loading { .. } => v_flex() .p_2() .flex_1()