agent_ui: Implement favorite model selection

This commit is contained in:
Alexey Orlenko
2025-12-05 19:15:21 +01:00
parent 7ed5d42696
commit ce884443f1
12 changed files with 684 additions and 84 deletions

1
Cargo.lock generated
View File

@@ -291,6 +291,7 @@ dependencies = [
name = "agent_settings"
version = "0.1.0"
dependencies = [
"agent-client-protocol",
"anyhow",
"cloud_llm_client",
"collections",

View File

@@ -202,6 +202,12 @@ pub trait AgentModelSelector: 'static {
fn should_render_footer(&self) -> bool {
false
}
/// Whether this selector supports the favorites feature.
/// Only the native agent uses the model ID format that maps to settings.
fn supports_favorites(&self) -> bool {
false
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -239,6 +245,10 @@ impl AgentModelList {
AgentModelList::Grouped(groups) => groups.is_empty(),
}
}
pub fn is_flat(&self) -> bool {
matches!(self, AgentModelList::Flat(_))
}
}
#[cfg(feature = "test-support")]

View File

@@ -944,6 +944,10 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
fn should_render_footer(&self) -> bool {
true
}
fn supports_favorites(&self) -> bool {
true
}
}
impl acp_thread::AgentConnection for NativeAgentConnection {

View File

@@ -12,6 +12,7 @@ workspace = true
path = "src/agent_settings.rs"
[dependencies]
agent-client-protocol.workspace = true
anyhow.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true

View File

@@ -2,7 +2,8 @@ mod agent_profile;
use std::sync::Arc;
use collections::IndexMap;
use agent_client_protocol::ModelId;
use collections::{HashSet, IndexMap};
use gpui::{App, Pixels, px};
use language_model::LanguageModel;
use project::DisableAiSettings;
@@ -31,6 +32,8 @@ pub struct AgentSettings {
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub favorite_models_as_selections: Vec<LanguageModelSelection>,
pub favorite_models_as_ids: Arc<HashSet<ModelId>>,
pub default_profile: AgentProfileId,
pub default_view: DefaultAgentView,
pub profiles: IndexMap<AgentProfileId, AgentProfileSettings>,
@@ -158,6 +161,16 @@ impl Settings for AgentSettings {
commit_message_model: agent.commit_message_model,
thread_summary_model: agent.thread_summary_model,
inline_alternatives: agent.inline_alternatives.unwrap_or_default(),
favorite_models_as_selections: agent.favorite_models,
favorite_models_as_ids: Arc::new(
content
.agent
.as_ref()
.iter()
.flat_map(|agent| &agent.favorite_models)
.map(|sel| ModelId::new(format!("{}/{}", sel.provider.0, sel.model)))
.collect(),
),
default_profile: AgentProfileId(agent.default_profile.unwrap()),
default_view: agent.default_view.unwrap(),
profiles: agent

View File

@@ -1,20 +1,24 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_client_protocol::ModelId;
use agent_servers::AgentServer;
use agent_settings::AgentSettings;
use anyhow::Result;
use collections::IndexMap;
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::FutureExt;
use fuzzy::{StringMatchCandidate, match_strings};
use gpui::{
Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
};
use itertools::Itertools;
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use settings::{LanguageModelSelection, Settings, update_settings_file};
use ui::{
DocumentationAside, DocumentationEdge, DocumentationSide, IntoElement, KeyBinding, ListItem,
ListItemSpacing, prelude::*,
ListItemSpacing, Tooltip, prelude::*,
};
use util::ResultExt;
use zed_actions::agent::OpenSettings;
@@ -41,7 +45,42 @@ pub fn acp_model_selector(
enum AcpModelPickerEntry {
Separator(SharedString),
Model(AgentModelInfo),
Model(AgentModelInfo, AcpModelPickerEntryAction),
}
/// Corresponds to the action button shown on the model in the list.
/// `Unfavorite` and `RemoveFromFavorites` are semantically the same but
/// correspond to different icons.
#[derive(Copy, Clone)]
enum AcpModelPickerEntryAction {
Favorite,
Unfavorite,
RemoveFromFavorites,
}
impl AcpModelPickerEntryAction {
fn for_model_in_general_section(model: &AgentModelInfo, favorites: &HashSet<ModelId>) -> Self {
if favorites.contains(&model.id) {
Self::Unfavorite
} else {
Self::Favorite
}
}
fn icon_name(&self) -> IconName {
match self {
Self::Favorite => IconName::Star,
Self::Unfavorite => IconName::StarFilled,
Self::RemoveFromFavorites => IconName::Trash,
}
}
fn tooltip(&self) -> SharedString {
match self {
Self::Favorite => "Add to favorites".into(),
Self::Unfavorite | Self::RemoveFromFavorites => "Remove from favorites".into(),
}
}
}
pub struct AcpModelPickerDelegate {
@@ -143,7 +182,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
_cx: &mut Context<Picker<Self>>,
) -> bool {
match self.filtered_entries.get(ix) {
Some(AcpModelPickerEntry::Model(_)) => true,
Some(AcpModelPickerEntry::Model(_, _)) => true,
Some(AcpModelPickerEntry::Separator(_)) | None => false,
}
}
@@ -158,6 +197,12 @@ impl PickerDelegate for AcpModelPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let favorites = if self.selector.supports_favorites() {
AgentSettings::get_global(cx).favorite_models_as_ids.clone()
} else {
Default::default()
};
cx.spawn_in(window, async move |this, cx| {
let filtered_models = match this
.read_with(cx, |this, cx| {
@@ -174,7 +219,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries =
info_list_to_picker_entries(filtered_models).collect();
info_list_to_picker_entries(filtered_models, favorites).collect();
// Finds the currently selected model in the list
let new_index = this
.delegate
@@ -182,7 +227,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
.as_ref()
.and_then(|selected| {
this.delegate.filtered_entries.iter().position(|entry| {
if let AcpModelPickerEntry::Model(model_info) = entry {
if let AcpModelPickerEntry::Model(model_info, _) = entry {
model_info.id == selected.id
} else {
false
@@ -198,7 +243,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
if let Some(AcpModelPickerEntry::Model(model_info)) =
if let Some(AcpModelPickerEntry::Model(model_info, _)) =
self.filtered_entries.get(self.selected_index)
{
if window.modifiers().secondary() {
@@ -258,7 +303,7 @@ impl PickerDelegate for AcpModelPickerDelegate {
)
.into_any_element(),
),
AcpModelPickerEntry::Model(model_info) => {
AcpModelPickerEntry::Model(model_info, action) => {
let is_selected = Some(model_info) == self.selected_model.as_ref();
let default_model = self.agent_server.default_model(cx);
let is_default = default_model.as_ref() == Some(&model_info.id);
@@ -269,6 +314,35 @@ impl PickerDelegate for AcpModelPickerDelegate {
Color::Muted
};
let handle_action_click = {
let fs = self.fs.clone();
let action = *action;
let model_id = model_info.id.clone();
move |cx: &App| {
let fs = fs.clone();
let model_id = model_id.0.as_ref();
let (provider, model) = model_id.split_once('/').unwrap_or(("", model_id));
let selection = LanguageModelSelection {
provider: provider.to_owned().into(),
model: model.to_owned(),
};
update_settings_file(fs, cx, move |settings, _| match action {
AcpModelPickerEntryAction::Favorite => settings
.agent
.get_or_insert_default()
.add_favorite_model(selection),
AcpModelPickerEntryAction::Unfavorite
| AcpModelPickerEntryAction::RemoveFromFavorites => settings
.agent
.get_or_insert_default()
.remove_favorite_model(&selection),
});
}
};
Some(
div()
.id(("model-picker-menu-child", ix))
@@ -307,7 +381,18 @@ impl PickerDelegate for AcpModelPickerDelegate {
.color(Color::Accent)
.size(IconSize::Small),
)
})),
}))
.end_hover_slot(
div().pr_3().when(self.selector.supports_favorites(), |this| {
this.child(
IconButton::new(("toggle-favorite", ix), action.icon_name())
.icon_color(model_icon_color)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text(action.tooltip()))
.on_click(move |_, _, cx| handle_action_click(cx)),
)
}),
)
)
.into_any_element()
)
@@ -376,17 +461,63 @@ impl PickerDelegate for AcpModelPickerDelegate {
fn info_list_to_picker_entries(
model_list: AgentModelList,
favorites: Arc<HashSet<ModelId>>,
) -> impl Iterator<Item = AcpModelPickerEntry> {
match model_list {
AgentModelList::Flat(list) => {
itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
}
let all_models = match &model_list {
AgentModelList::Flat(list) => itertools::Either::Left(list.iter()),
AgentModelList::Grouped(index_map) => {
itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
.chain(models.into_iter().map(AcpModelPickerEntry::Model))
itertools::Either::Right(index_map.values().flatten())
}
};
let favorites_entries = all_models
.filter(|model_info| favorites.contains(&model_info.id))
.unique_by(|model_info| &model_info.id)
.map(|model_info| {
AcpModelPickerEntry::Model(
model_info.clone(),
AcpModelPickerEntryAction::RemoveFromFavorites,
)
})
.collect_vec();
let model_list_is_flat = model_list.is_flat();
let all_models_entries = match model_list {
AgentModelList::Flat(list) => {
itertools::Either::Left(list.into_iter().map(move |model_info| {
let action = AcpModelPickerEntryAction::for_model_in_general_section(
&model_info,
&favorites,
);
AcpModelPickerEntry::Model(model_info, action)
}))
}
AgentModelList::Grouped(index_map) => {
itertools::Either::Right(index_map.into_iter().flat_map(move |(group_name, models)| {
let favorites = favorites.clone();
std::iter::once(AcpModelPickerEntry::Separator(group_name.0)).chain(
models.into_iter().map(move |model_info| {
let action = AcpModelPickerEntryAction::for_model_in_general_section(
&model_info,
&favorites,
);
AcpModelPickerEntry::Model(model_info, action)
}),
)
}))
}
};
if favorites_entries.is_empty() {
itertools::Either::Left(all_models_entries)
} else {
itertools::Either::Right(
std::iter::once(AcpModelPickerEntry::Separator("Favorite".into()))
.chain(favorites_entries)
.chain(model_list_is_flat.then_some(AcpModelPickerEntry::Separator("All".into())))
.chain(all_models_entries),
)
}
}
@@ -511,6 +642,184 @@ mod tests {
}
}
fn create_favorites(models: Vec<&str>) -> Arc<HashSet<ModelId>> {
Arc::new(
models
.into_iter()
.map(|m| ModelId::new(m.to_string()))
.collect(),
)
}
fn get_entry_model_ids(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
entries
.iter()
.filter_map(|entry| match entry {
AcpModelPickerEntry::Model(info, _) => Some(info.id.0.as_ref()),
_ => None,
})
.collect()
}
fn get_entry_labels(entries: &[AcpModelPickerEntry]) -> Vec<&str> {
entries
.iter()
.map(|entry| match entry {
AcpModelPickerEntry::Model(info, _) => info.id.0.as_ref(),
AcpModelPickerEntry::Separator(s) => &s,
})
.collect()
}
#[gpui::test]
fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
let models = create_model_list(vec![
("zed", vec!["zed/claude", "zed/gemini"]),
("openai", vec!["openai/gpt-5"]),
]);
let favorites = create_favorites(vec!["zed/gemini"]);
let entries = info_list_to_picker_entries(models, favorites).collect_vec();
assert!(matches!(
entries.first(),
Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
));
let model_ids = get_entry_model_ids(&entries);
assert_eq!(model_ids[0], "zed/gemini");
}
#[gpui::test]
fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
let models = create_model_list(vec![("zed", vec!["zed/claude", "zed/gemini"])]);
let favorites = create_favorites(vec![]);
let entries = info_list_to_picker_entries(models, favorites).collect_vec();
assert!(matches!(
entries.first(),
Some(AcpModelPickerEntry::Separator(s)) if s == "zed"
));
}
#[gpui::test]
fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
let models = create_model_list(vec![
("zed", vec!["zed/claude", "zed/gemini"]),
("openai", vec!["openai/gpt-5"]),
]);
let favorites = create_favorites(vec!["zed/claude"]);
let entries = info_list_to_picker_entries(models, favorites).collect_vec();
let mut in_favorites_section = false;
for entry in &entries {
match entry {
AcpModelPickerEntry::Separator(s) if s == "Favorite" => {
in_favorites_section = true;
}
AcpModelPickerEntry::Separator(_) => {
in_favorites_section = false;
}
AcpModelPickerEntry::Model(_, action) if in_favorites_section => {
assert!(matches!(
action,
AcpModelPickerEntryAction::RemoveFromFavorites
));
}
AcpModelPickerEntry::Model(info, action) if info.id.0.as_ref() == "zed/claude" => {
assert!(matches!(action, AcpModelPickerEntryAction::Unfavorite));
}
AcpModelPickerEntry::Model(_, action) => {
assert!(matches!(action, AcpModelPickerEntryAction::Favorite));
}
}
}
}
#[gpui::test]
fn test_favorites_appear_in_both_sections(_cx: &mut TestAppContext) {
let models = create_model_list(vec![
("zed", vec!["zed/claude", "zed/gemini"]),
("openai", vec!["openai/gpt-5", "openai/gpt-4"]),
]);
let favorites = create_favorites(vec!["zed/gemini", "openai/gpt-5"]);
let entries = info_list_to_picker_entries(models, favorites).collect_vec();
let model_ids = get_entry_model_ids(&entries);
assert_eq!(model_ids[0], "zed/gemini");
assert_eq!(model_ids[1], "openai/gpt-5");
assert!(model_ids[2..].contains(&"zed/gemini"));
assert!(model_ids[2..].contains(&"openai/gpt-5"));
}
#[gpui::test]
fn test_favorites_are_not_duplicated_when_repeated_in_other_sections(_cx: &mut TestAppContext) {
let models = create_model_list(vec![
("Recommended", vec!["zed/claude", "anthropic/claude"]),
("Zed", vec!["zed/claude", "zed/gpt-5"]),
("Antropic", vec!["anthropic/claude"]),
("OpenAI", vec!["openai/gpt-5"]),
]);
let favorites = create_favorites(vec!["zed/claude"]);
let entries = info_list_to_picker_entries(models, favorites).collect_vec();
let labels = get_entry_labels(&entries);
assert_eq!(
labels,
vec![
"Favorite",
"zed/claude",
"Recommended",
"zed/claude",
"anthropic/claude",
"Zed",
"zed/claude",
"zed/gpt-5",
"Antropic",
"anthropic/claude",
"OpenAI",
"openai/gpt-5"
]
);
}
#[gpui::test]
fn test_flat_model_list_with_favorites(_cx: &mut TestAppContext) {
let models = AgentModelList::Flat(vec![
acp_thread::AgentModelInfo {
id: acp::ModelId::new("zed/claude".to_string()),
name: "Claude".into(),
description: None,
icon: None,
},
acp_thread::AgentModelInfo {
id: acp::ModelId::new("zed/gemini".to_string()),
name: "Gemini".into(),
description: None,
icon: None,
},
]);
let favorites = create_favorites(vec!["zed/gemini"]);
let entries = info_list_to_picker_entries(models, favorites).collect_vec();
assert!(matches!(
entries.first(),
Some(AcpModelPickerEntry::Separator(s)) if s == "Favorite"
));
assert!(entries.iter().any(|e| matches!(
e,
AcpModelPickerEntry::Separator(s) if s == "All"
)));
}
#[gpui::test]
async fn test_fuzzy_match(cx: &mut TestAppContext) {
let models = create_model_list(vec![

View File

@@ -235,23 +235,27 @@ impl ManageProfilesModal {
})
}
},
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
let profile_id = profile_id.clone();
{
let fs = fs.clone();
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
let profile_id = profile_id.clone();
update_settings_file(fs.clone(), cx, move |settings, _cx| {
let agent_settings = settings.agent.get_or_insert_default();
if let Some(profiles) = agent_settings.profiles.as_mut() {
if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) {
profile.default_model = Some(LanguageModelSelection {
provider: LanguageModelProviderSetting(provider.clone()),
model: model_id.clone(),
});
update_settings_file(fs.clone(), cx, move |settings, _cx| {
let agent_settings = settings.agent.get_or_insert_default();
if let Some(profiles) = agent_settings.profiles.as_mut() {
if let Some(profile) = profiles.get_mut(profile_id.0.as_ref()) {
profile.default_model = Some(LanguageModelSelection {
provider: LanguageModelProviderSetting(provider.clone()),
model: model_id.clone(),
});
}
}
}
});
});
}
},
fs,
false, // Do not use popover styles for the model picker
self.focus_handle.clone(),
window,

View File

@@ -35,20 +35,24 @@ impl AgentModelSelector {
let model_context = model_usage_context.clone();
move |cx| model_context.configured_model(cx)
},
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match &model_usage_context {
ModelUsageContext::InlineAssistant => {
update_settings_file(fs.clone(), cx, move |settings, _cx| {
settings
.agent
.get_or_insert_default()
.set_inline_assistant_model(provider.clone(), model_id);
});
{
let fs = fs.clone();
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match &model_usage_context {
ModelUsageContext::InlineAssistant => {
update_settings_file(fs.clone(), cx, move |settings, _cx| {
settings
.agent
.get_or_insert_default()
.set_inline_assistant_model(provider.clone(), model_id);
});
}
}
}
},
fs,
true, // Use popover styles for picker
focus_handle_clone,
window,

View File

@@ -448,6 +448,8 @@ mod tests {
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: vec![],
favorite_models_as_selections: vec![],
favorite_models_as_ids: Arc::new(Default::default()),
default_profile: AgentProfileId::default(),
default_view: DefaultAgentView::Thread,
profiles: Default::default(),

View File

@@ -1,17 +1,20 @@
use std::{cmp::Reverse, sync::Arc};
use collections::IndexMap;
use agent_settings::AgentSettings;
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProvider,
LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
use settings::{LanguageModelSelection, Settings, update_settings_file};
use ui::{KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
use zed_actions::agent::OpenSettings;
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
@@ -22,6 +25,7 @@ pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
pub fn language_model_selector(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
fs: Arc<dyn Fs>,
popover_styles: bool,
focus_handle: FocusHandle,
window: &mut Window,
@@ -30,6 +34,7 @@ pub fn language_model_selector(
let delegate = LanguageModelPickerDelegate::new(
get_active_model,
on_model_changed,
fs,
popover_styles,
focus_handle,
window,
@@ -47,7 +52,17 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
let lm_registry = LanguageModelRegistry::global(cx).read(cx);
let providers = lm_registry.providers();
let mut favorites_index = FavoritesIndex::default();
for sel in &AgentSettings::get_global(cx).favorite_models_as_selections {
favorites_index
.entry(sel.provider.0.clone().into())
.or_default()
.insert(sel.model.clone().into());
}
let recommended = providers
.iter()
@@ -55,10 +70,7 @@ fn all_models(cx: &App) -> GroupedModels {
provider
.recommended_models(cx)
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
})
.map(|model| ModelInfo::new(&**provider, model, &favorites_index))
})
.collect();
@@ -68,20 +80,38 @@ fn all_models(cx: &App) -> GroupedModels {
provider
.provided_models(cx)
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
})
.map(|model| ModelInfo::new(&**provider, model, &favorites_index))
})
.collect();
GroupedModels::new(all, recommended)
}
type FavoritesIndex = HashMap<LanguageModelProviderId, HashSet<LanguageModelId>>;
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
is_favorite: bool,
}
impl ModelInfo {
fn new(
provider: &dyn LanguageModelProvider,
model: Arc<dyn LanguageModel>,
favorites_index: &FavoritesIndex,
) -> Self {
let is_favorite = favorites_index
.get(&provider.id())
.map_or(false, |set| set.contains(&model.id()));
Self {
model,
icon: provider.icon(),
is_favorite,
}
}
}
pub struct LanguageModelPickerDelegate {
@@ -94,12 +124,14 @@ pub struct LanguageModelPickerDelegate {
_subscriptions: Vec<Subscription>,
popover_styles: bool,
focus_handle: FocusHandle,
fs: Arc<dyn Fs>,
}
impl LanguageModelPickerDelegate {
fn new(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
fs: Arc<dyn Fs>,
popover_styles: bool,
focus_handle: FocusHandle,
window: &mut Window,
@@ -136,6 +168,7 @@ impl LanguageModelPickerDelegate {
)],
popover_styles,
focus_handle,
fs,
}
}
@@ -146,7 +179,7 @@ impl LanguageModelPickerDelegate {
entries
.iter()
.position(|entry| {
if let LanguageModelPickerEntry::Model(model) = entry {
if let LanguageModelPickerEntry::Model(model, _) = entry {
active_model
.as_ref()
.map(|active_model| {
@@ -217,12 +250,19 @@ impl LanguageModelPickerDelegate {
}
struct GroupedModels {
favorites: Vec<ModelInfo>,
recommended: Vec<ModelInfo>,
all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
impl GroupedModels {
pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
let favorites = all
.iter()
.filter(|info| info.is_favorite)
.cloned()
.collect();
let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
for model in all {
let provider = model.model.provider_id();
@@ -234,6 +274,7 @@ impl GroupedModels {
}
Self {
favorites,
recommended,
all: all_by_provider,
}
@@ -242,13 +283,24 @@ impl GroupedModels {
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
let mut entries = Vec::new();
if !self.favorites.is_empty() {
entries.push(LanguageModelPickerEntry::Separator("Favorite".into()));
entries.extend(self.favorites.iter().map(|info| {
LanguageModelPickerEntry::Model(
info.clone(),
LanguageModelPickerEntryAction::RemoveFromFavorites,
)
}));
}
if !self.recommended.is_empty() {
entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
entries.extend(
self.recommended
.iter()
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
);
entries.extend(self.recommended.iter().map(|info| {
LanguageModelPickerEntry::Model(
info.clone(),
LanguageModelPickerEntryAction::for_model_in_general_section(&info),
)
}));
}
for models in self.all.values() {
@@ -258,21 +310,57 @@ impl GroupedModels {
entries.push(LanguageModelPickerEntry::Separator(
models[0].model.provider_name().0,
));
entries.extend(
models
.iter()
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
);
entries.extend(models.iter().map(|info| {
LanguageModelPickerEntry::Model(
info.clone(),
LanguageModelPickerEntryAction::for_model_in_general_section(&info),
)
}));
}
entries
}
}
enum LanguageModelPickerEntry {
Model(ModelInfo),
Model(ModelInfo, LanguageModelPickerEntryAction),
Separator(SharedString),
}
/// Corresponds to the action button shown on the model in the list.
/// `Unfavorite` and `RemoveFromFavorites` are semantically the same but
/// correspond to different icons.
#[derive(Copy, Clone)]
enum LanguageModelPickerEntryAction {
Favorite,
Unfavorite,
RemoveFromFavorites,
}
impl LanguageModelPickerEntryAction {
fn for_model_in_general_section(model: &ModelInfo) -> Self {
if model.is_favorite {
Self::Unfavorite
} else {
Self::Favorite
}
}
fn icon_name(&self) -> IconName {
match self {
Self::Favorite => IconName::Star,
Self::Unfavorite => IconName::StarFilled,
Self::RemoveFromFavorites => IconName::Trash,
}
}
fn tooltip(&self) -> SharedString {
match self {
Self::Favorite => "Add to favorites".into(),
Self::Unfavorite | Self::RemoveFromFavorites => "Remove from favorites".into(),
}
}
}
struct ModelMatcher {
models: Vec<ModelInfo>,
bg_executor: BackgroundExecutor,
@@ -369,7 +457,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
_cx: &mut Context<Picker<Self>>,
) -> bool {
match self.filtered_entries.get(ix) {
Some(LanguageModelPickerEntry::Model(_)) => true,
Some(LanguageModelPickerEntry::Model(_, _)) => true,
Some(LanguageModelPickerEntry::Separator(_)) | None => false,
}
}
@@ -439,7 +527,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
if let Some(LanguageModelPickerEntry::Model(model_info)) =
if let Some(LanguageModelPickerEntry::Model(model_info, _)) =
self.filtered_entries.get(self.selected_index)
{
let model = model_info.model.clone();
@@ -481,7 +569,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
)
.into_any_element(),
),
LanguageModelPickerEntry::Model(model_info) => {
LanguageModelPickerEntry::Model(model_info, action) => {
let active_model = (self.get_active_model)(cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id());
@@ -495,6 +583,30 @@ impl PickerDelegate for LanguageModelPickerDelegate {
Color::Muted
};
let handle_action_click = {
let fs = self.fs.clone();
let action = *action;
let model = LanguageModelSelection {
provider: model_info.model.provider_id().to_string().into(),
model: model_info.model.id().0.to_string(),
};
move |cx: &App| {
let fs = fs.clone();
let model = model.clone();
update_settings_file(fs, cx, move |settings, _| match action {
LanguageModelPickerEntryAction::Favorite => settings
.agent
.get_or_insert_default()
.add_favorite_model(model),
LanguageModelPickerEntryAction::Unfavorite
| LanguageModelPickerEntryAction::RemoveFromFavorites => settings
.agent
.get_or_insert_default()
.remove_favorite_model(&model),
});
}
};
Some(
ListItem::new(ix)
.inset(true)
@@ -518,6 +630,15 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.size(IconSize::Small),
)
}))
.end_hover_slot(
div().pr_3().child(
IconButton::new(("toggle-favorite", ix), action.icon_name())
.icon_color(model_icon_color)
.icon_size(IconSize::Small)
.tooltip(Tooltip::text(action.tooltip()))
.on_click(move |_, _, cx| handle_action_click(cx)),
),
)
.into_any_element(),
)
}
@@ -653,11 +774,24 @@ mod tests {
}
fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
create_models_with_favorites(model_specs, vec![])
}
fn create_models_with_favorites(
model_specs: Vec<(&str, &str)>,
favorites: Vec<(&str, &str)>,
) -> Vec<ModelInfo> {
model_specs
.into_iter()
.map(|(provider, name)| ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
icon: IconName::Ai,
.map(|(provider, name)| {
let is_favorite = favorites
.iter()
.any(|(fav_provider, fav_name)| *fav_provider == provider && *fav_name == name);
ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
icon: IconName::Ai,
is_favorite,
}
})
.collect()
}
@@ -795,4 +929,105 @@ mod tests {
vec!["zed/claude", "zed/gemini", "copilot/claude"],
);
}
#[gpui::test]
fn test_favorites_section_appears_when_favorites_exist(_cx: &mut TestAppContext) {
let recommended_models = create_models(vec![("zed", "claude")]);
let all_models = create_models_with_favorites(
vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
vec![("zed", "gemini")],
);
let grouped_models = GroupedModels::new(all_models, recommended_models);
let entries = grouped_models.entries();
assert!(matches!(
entries.first(),
Some(LanguageModelPickerEntry::Separator(s)) if s == "Favorite"
));
assert_models_eq(grouped_models.favorites, vec!["zed/gemini"]);
}
#[gpui::test]
fn test_no_favorites_section_when_no_favorites(_cx: &mut TestAppContext) {
let recommended_models = create_models(vec![("zed", "claude")]);
let all_models = create_models(vec![("zed", "claude"), ("zed", "gemini")]);
let grouped_models = GroupedModels::new(all_models, recommended_models);
let entries = grouped_models.entries();
assert!(matches!(
entries.first(),
Some(LanguageModelPickerEntry::Separator(s)) if s == "Recommended"
));
assert!(grouped_models.favorites.is_empty());
}
#[gpui::test]
fn test_models_have_correct_actions(_cx: &mut TestAppContext) {
let recommended_models =
create_models_with_favorites(vec![("zed", "claude")], vec![("zed", "claude")]);
let all_models = create_models_with_favorites(
vec![("zed", "claude"), ("zed", "gemini"), ("openai", "gpt-4")],
vec![("zed", "claude")],
);
let grouped_models = GroupedModels::new(all_models, recommended_models);
let entries = grouped_models.entries();
let mut in_favorites_section = false;
for entry in &entries {
match entry {
LanguageModelPickerEntry::Separator(s) if s == "Favorite" => {
in_favorites_section = true;
}
LanguageModelPickerEntry::Separator(_) => {
in_favorites_section = false;
}
LanguageModelPickerEntry::Model(_, action) if in_favorites_section => {
assert!(matches!(
action,
LanguageModelPickerEntryAction::RemoveFromFavorites
));
}
LanguageModelPickerEntry::Model(info, action)
if !in_favorites_section && info.model.telemetry_id() == "zed/claude" =>
{
assert!(matches!(action, LanguageModelPickerEntryAction::Unfavorite));
}
LanguageModelPickerEntry::Model(_, action) => {
assert!(matches!(action, LanguageModelPickerEntryAction::Favorite));
}
}
}
}
#[gpui::test]
fn test_favorites_appear_in_other_sections(_cx: &mut TestAppContext) {
let favorites = vec![("zed", "gemini"), ("openai", "gpt-4")];
let recommended_models =
create_models_with_favorites(vec![("zed", "claude")], favorites.clone());
let all_models = create_models_with_favorites(
vec![
("zed", "claude"),
("zed", "gemini"),
("openai", "gpt-4"),
("openai", "gpt-3.5"),
],
favorites,
);
let grouped_models = GroupedModels::new(all_models, recommended_models);
assert_models_eq(grouped_models.favorites, vec!["zed/gemini", "openai/gpt-4"]);
assert_models_eq(grouped_models.recommended, vec!["zed/claude"]);
assert_models_eq(
grouped_models.all.values().flatten().cloned().collect(),
vec!["zed/claude", "zed/gemini", "openai/gpt-4", "openai/gpt-3.5"],
);
}
}

View File

@@ -304,18 +304,22 @@ impl TextThreadEditor {
language_model_selector: cx.new(|cx| {
language_model_selector(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file(fs.clone(), cx, move |settings, _| {
let provider = model.provider_id().0.to_string();
let model = model.id().0.to_string();
settings.agent.get_or_insert_default().set_model(
LanguageModelSelection {
provider: LanguageModelProviderSetting(provider),
model,
},
)
});
{
let fs = fs.clone();
move |model, cx| {
update_settings_file(fs.clone(), cx, move |settings, _| {
let provider = model.provider_id().0.to_string();
let model = model.id().0.to_string();
settings.agent.get_or_insert_default().set_model(
LanguageModelSelection {
provider: LanguageModelProviderSetting(provider),
model,
},
)
});
}
},
fs,
true, // Use popover styles for picker
focus_handle,
window,

View File

@@ -34,6 +34,9 @@ pub struct AgentSettingsContent {
pub default_height: Option<f32>,
/// The default model to use when creating new chats and for other features when a specific model is not specified.
pub default_model: Option<LanguageModelSelection>,
/// Favorite models to show at the top of the model selector.
#[serde(default)]
pub favorite_models: Vec<LanguageModelSelection>,
/// Model to use for the inline assistant. Defaults to default_model when not specified.
pub inline_assistant_model: Option<LanguageModelSelection>,
/// Model to use for generating git commit messages. Defaults to default_model when not specified.
@@ -163,6 +166,16 @@ impl AgentSettingsContent {
pub fn set_profile(&mut self, profile_id: Arc<str>) {
self.default_profile = Some(profile_id);
}
pub fn add_favorite_model(&mut self, model: LanguageModelSelection) {
if !self.favorite_models.contains(&model) {
self.favorite_models.push(model);
}
}
pub fn remove_favorite_model(&mut self, model: &LanguageModelSelection) {
self.favorite_models.retain(|m| m != model);
}
}
#[with_fallible_options]