Compare commits

...

2 Commits

Author SHA1 Message Date
Richard Feldman
c6be76d927 Make the reminder prompt more aggro 2025-04-15 17:00:33 -04:00
Richard Feldman
9741a87046 Add reminder for Gemini models to try to force tool use. 2025-04-15 16:52:03 -04:00
5 changed files with 124 additions and 14 deletions

1
Cargo.lock generated
View File

@@ -11026,6 +11026,7 @@ dependencies = [
"paths",
"rope",
"serde",
"serde_json",
"text",
"util",
"uuid",

View File

@@ -0,0 +1,9 @@
The following tools are available for you to use: {{#each tools}}{{#if @index}}, {{/if}}{{this}}{{/each}}.
There is also a special tool named final_response which takes no input. Use this tool when you are completely finished,
and do not need to run any more tools because the original request has been satisfied.
You MUST respond with one of these tool uses. If you are completely done, then use the final_response tool.
If you respond with ANYTHING other than a tool use, I will know that you have disregarded my instructions and I will
be very disappointed, and will just prompt you once again to respond with a tool use. So you MUST respond with a
tool use RIGHT NOW NO MATTER WHAT.

View File

@@ -840,7 +840,7 @@ impl Thread {
request_kind: RequestKind,
cx: &mut Context<Self>,
) {
let mut request = self.to_completion_request(request_kind, cx);
let mut request = self.to_completion_request(request_kind, model.clone(), cx);
if model.supports_tools() {
request.tools = {
let mut tools = Vec::new();
@@ -882,6 +882,7 @@ impl Thread {
pub fn to_completion_request(
&self,
request_kind: RequestKind,
model: Arc<dyn LanguageModel>,
cx: &App,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
@@ -954,6 +955,33 @@ impl Thread {
self.attached_tracked_files_state(&mut request.messages, cx);
// Gemini models need extremely strong encouragement to
// get them to actually use tools, so we add a reminder
// message at the end of each user request.
if let Some(last_user_message) = request
.messages
.iter_mut()
.rev()
.find(|msg| msg.role == Role::User)
{
if model.id().0.contains("gemini") {
let enabled_tools = self
.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.map(|tool| tool.name())
.collect::<Vec<_>>();
last_user_message
.content
.push(MessageContent::Text(gemini_reminder(
&self.prompt_builder,
enabled_tools,
)));
}
}
request
}
@@ -1208,15 +1236,17 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
let Some(ConfiguredModel { model, provider }) =
LanguageModelRegistry::read_global(cx).thread_summary_model()
else {
return;
};
if !model.provider.is_authenticated(cx) {
if !provider.is_authenticated(cx) {
return;
}
let mut request = self.to_completion_request(RequestKind::Summarize, cx);
let mut request = self.to_completion_request(RequestKind::Summarize, model.clone(), cx);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
@@ -1231,7 +1261,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.model.stream_completion_text(request, &cx);
let stream = model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut new_summary = String::new();
@@ -1282,7 +1312,7 @@ impl Thread {
return None;
}
let mut request = self.to_completion_request(RequestKind::Summarize, cx);
let mut request = self.to_completion_request(RequestKind::Summarize, model.clone(), cx);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
@@ -1344,7 +1374,12 @@ impl Thread {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = self.to_completion_request(RequestKind::Chat, cx);
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
let request = self.to_completion_request(RequestKind::Chat, model, cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
.tool_use
@@ -1462,12 +1497,33 @@ impl Thread {
}
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
let model_registry = LanguageModelRegistry::read_global(cx);
let is_gemini = model_registry
.default_model()
.map_or(false, |model| model.model.id().0.contains("gemini"));
// Get the list of enabled tools if we're using Gemini
let mut message = "Here are the tool results.".to_string();
if is_gemini {
let enabled_tools = self
.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.map(|tool| tool.name())
.collect::<Vec<_>>();
message.push_str("\n\n");
message.push_str(&gemini_reminder(&self.prompt_builder, enabled_tools));
}
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
// responses that also don't have any content. We currently don't handle this case well,
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
&message,
Vec::new(),
None,
cx,
@@ -1963,6 +2019,12 @@ pub enum ThreadEvent {
impl EventEmitter<ThreadEvent> for Thread {}
pub fn gemini_reminder(prompt_builder: &prompt_store::PromptBuilder, tools: Vec<String>) -> String {
prompt_builder
.generate_gemini_reminder(tools)
.unwrap_or_default()
}
struct PendingCompletion {
id: usize,
_task: Task<()>,
@@ -2045,7 +2107,12 @@ fn main() {{
// Check message in request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
thread.to_completion_request(RequestKind::Chat, model, cx)
});
assert_eq!(request.messages.len(), 2);
@@ -2137,7 +2204,12 @@ fn main() {{
// Check entire request to make sure all contexts are properly included
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
thread.to_completion_request(RequestKind::Chat, model, cx)
});
// The request should contain all 3 messages
@@ -2189,7 +2261,12 @@ fn main() {{
// Check message in request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
thread.to_completion_request(RequestKind::Chat, model, cx)
});
assert_eq!(request.messages.len(), 2);
@@ -2209,7 +2286,12 @@ fn main() {{
// Check that both messages appear in the request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
thread.to_completion_request(RequestKind::Chat, model, cx)
});
assert_eq!(request.messages.len(), 3);
@@ -2251,7 +2333,12 @@ fn main() {{
// Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
thread.to_completion_request(RequestKind::Chat, model, cx)
});
// Make sure we don't have a stale file warning yet
@@ -2281,7 +2368,12 @@ fn main() {{
// Create a new request and check for the stale buffer warning
let new_request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.unwrap()
.model
.clone();
thread.to_completion_request(RequestKind::Chat, model, cx)
});
// We should have a stale file warning as the last message

View File

@@ -28,6 +28,7 @@ parking_lot.workspace = true
paths.workspace = true
rope.workspace = true
serde.workspace = true
serde_json.workspace = true
text.workspace = true
util.workspace = true
uuid.workspace = true

View File

@@ -7,6 +7,7 @@ use handlebars::{Handlebars, RenderError};
use language::{BufferSnapshot, LanguageName, Point};
use parking_lot::Mutex;
use serde::Serialize;
use serde_json;
use std::{
ops::Range,
path::{Path, PathBuf},
@@ -267,6 +268,12 @@ impl PromptBuilder {
.render("assistant_system_prompt", context)
}
pub fn generate_gemini_reminder(&self, tools: Vec<String>) -> Result<String, RenderError> {
self.handlebars
.lock()
.render("gemini_reminder", &serde_json::json!({"tools": tools}))
}
pub fn generate_inline_transformation_prompt(
&self,
user_prompt: String,