Compare commits
2 Commits
v0.211.6
...
gemini-rem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6be76d927 | ||
|
|
9741a87046 |
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -11026,6 +11026,7 @@ dependencies = [
|
||||
"paths",
|
||||
"rope",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text",
|
||||
"util",
|
||||
"uuid",
|
||||
|
||||
9
assets/prompts/gemini_reminder.hbs
Normal file
9
assets/prompts/gemini_reminder.hbs
Normal file
@@ -0,0 +1,9 @@
|
||||
The following tools are available for you to use: {{#each tools}}{{#if @index}}, {{/if}}{{this}}{{/each}}.
|
||||
|
||||
There is also a special tool named final_response which takes no input. Use this tool when you are completely finished,
|
||||
and do not need to run any more tools because the original request has been satisfied.
|
||||
|
||||
You MUST respond with one of these tool uses. If you are completely done, then use the final_response tool.
|
||||
If you respond with ANYTHING other than a tool use, I will know that you have disregarded my instructions and I will
|
||||
be very disappointed, and will just prompt you once again to respond with a tool use. So you MUST respond with a
|
||||
tool use RIGHT NOW NO MATTER WHAT.
|
||||
@@ -840,7 +840,7 @@ impl Thread {
|
||||
request_kind: RequestKind,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let mut request = self.to_completion_request(request_kind, cx);
|
||||
let mut request = self.to_completion_request(request_kind, model.clone(), cx);
|
||||
if model.supports_tools() {
|
||||
request.tools = {
|
||||
let mut tools = Vec::new();
|
||||
@@ -882,6 +882,7 @@ impl Thread {
|
||||
pub fn to_completion_request(
|
||||
&self,
|
||||
request_kind: RequestKind,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &App,
|
||||
) -> LanguageModelRequest {
|
||||
let mut request = LanguageModelRequest {
|
||||
@@ -954,6 +955,33 @@ impl Thread {
|
||||
|
||||
self.attached_tracked_files_state(&mut request.messages, cx);
|
||||
|
||||
// Gemini models need extremely strong encouragement to
|
||||
// get them to actually use tools, so we add a reminder
|
||||
// to the end of the last user message in each request.
|
||||
if let Some(last_user_message) = request
|
||||
.messages
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|msg| msg.role == Role::User)
|
||||
{
|
||||
if model.id().0.contains("gemini") {
|
||||
let enabled_tools = self
|
||||
.tools()
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
last_user_message
|
||||
.content
|
||||
.push(MessageContent::Text(gemini_reminder(
|
||||
&self.prompt_builder,
|
||||
enabled_tools,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
@@ -1208,15 +1236,17 @@ impl Thread {
|
||||
}
|
||||
|
||||
pub fn summarize(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
|
||||
let Some(ConfiguredModel { model, provider }) =
|
||||
LanguageModelRegistry::read_global(cx).thread_summary_model()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
if !provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut request = self.to_completion_request(RequestKind::Summarize, cx);
|
||||
let mut request = self.to_completion_request(RequestKind::Summarize, model.clone(), cx);
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![
|
||||
@@ -1231,7 +1261,7 @@ impl Thread {
|
||||
|
||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let stream = model.model.stream_completion_text(request, &cx);
|
||||
let stream = model.stream_completion_text(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
let mut new_summary = String::new();
|
||||
@@ -1282,7 +1312,7 @@ impl Thread {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request = self.to_completion_request(RequestKind::Summarize, cx);
|
||||
let mut request = self.to_completion_request(RequestKind::Summarize, model.clone(), cx);
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
@@ -1344,7 +1374,12 @@ impl Thread {
|
||||
|
||||
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
|
||||
self.auto_capture_telemetry(cx);
|
||||
let request = self.to_completion_request(RequestKind::Chat, cx);
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
let request = self.to_completion_request(RequestKind::Chat, model, cx);
|
||||
let messages = Arc::new(request.messages);
|
||||
let pending_tool_uses = self
|
||||
.tool_use
|
||||
@@ -1462,12 +1497,33 @@ impl Thread {
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let is_gemini = model_registry
|
||||
.default_model()
|
||||
.map_or(false, |model| model.model.id().0.contains("gemini"));
|
||||
|
||||
// Base text for the tool-results message; when using Gemini, the tool-use reminder is appended below.
|
||||
let mut message = "Here are the tool results.".to_string();
|
||||
|
||||
if is_gemini {
|
||||
let enabled_tools = self
|
||||
.tools()
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
message.push_str("\n\n");
|
||||
message.push_str(&gemini_reminder(&self.prompt_builder, enabled_tools));
|
||||
}
|
||||
|
||||
// Insert a user message to contain the tool results.
|
||||
self.insert_user_message(
|
||||
// TODO: Sending up a user message without any content results in the model sending back
|
||||
// responses that also don't have any content. We currently don't handle this case well,
|
||||
// so for now we provide some text to keep the model on track.
|
||||
"Here are the tool results.",
|
||||
&message,
|
||||
Vec::new(),
|
||||
None,
|
||||
cx,
|
||||
@@ -1963,6 +2019,12 @@ pub enum ThreadEvent {
|
||||
|
||||
impl EventEmitter<ThreadEvent> for Thread {}
|
||||
|
||||
/// Builds the Gemini tool-use reminder text for the given tool names.
///
/// Delegates to `PromptBuilder::generate_gemini_reminder`, passing `tools`
/// through unchanged, and returns the rendered text.
///
/// NOTE(review): a render failure is not reported to the caller;
/// `unwrap_or_default()` replaces the error with an empty string.
pub fn gemini_reminder(prompt_builder: &prompt_store::PromptBuilder, tools: Vec<String>) -> String {
|
||||
prompt_builder
|
||||
.generate_gemini_reminder(tools)
|
||||
// Any `RenderError` is dropped here and becomes `String::default()`.
.unwrap_or_default()
|
||||
}
|
||||
|
||||
struct PendingCompletion {
|
||||
id: usize,
|
||||
_task: Task<()>,
|
||||
@@ -2045,7 +2107,12 @@ fn main() {{
|
||||
|
||||
// Check message in request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
thread.to_completion_request(RequestKind::Chat, model, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
@@ -2137,7 +2204,12 @@ fn main() {{
|
||||
|
||||
// Check entire request to make sure all contexts are properly included
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
thread.to_completion_request(RequestKind::Chat, model, cx)
|
||||
});
|
||||
|
||||
// The request should contain all 3 messages
|
||||
@@ -2189,7 +2261,12 @@ fn main() {{
|
||||
|
||||
// Check message in request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
thread.to_completion_request(RequestKind::Chat, model, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 2);
|
||||
@@ -2209,7 +2286,12 @@ fn main() {{
|
||||
|
||||
// Check that both messages appear in the request
|
||||
let request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
thread.to_completion_request(RequestKind::Chat, model, cx)
|
||||
});
|
||||
|
||||
assert_eq!(request.messages.len(), 3);
|
||||
@@ -2251,7 +2333,12 @@ fn main() {{
|
||||
|
||||
// Create a request and check that it doesn't have a stale buffer warning yet
|
||||
let initial_request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
thread.to_completion_request(RequestKind::Chat, model, cx)
|
||||
});
|
||||
|
||||
// Make sure we don't have a stale file warning yet
|
||||
@@ -2281,7 +2368,12 @@ fn main() {{
|
||||
|
||||
// Create a new request and check for the stale buffer warning
|
||||
let new_request = thread.read_with(cx, |thread, cx| {
|
||||
thread.to_completion_request(RequestKind::Chat, cx)
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
thread.to_completion_request(RequestKind::Chat, model, cx)
|
||||
});
|
||||
|
||||
// We should have a stale file warning as the last message
|
||||
|
||||
@@ -28,6 +28,7 @@ parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
rope.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
text.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -7,6 +7,7 @@ use handlebars::{Handlebars, RenderError};
|
||||
use language::{BufferSnapshot, LanguageName, Point};
|
||||
use parking_lot::Mutex;
|
||||
use serde::Serialize;
|
||||
use serde_json;
|
||||
use std::{
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
@@ -267,6 +268,12 @@ impl PromptBuilder {
|
||||
.render("assistant_system_prompt", context)
|
||||
}
|
||||
|
||||
/// Renders the `gemini_reminder` template, exposing `tools` to the template
/// under the `tools` key of a JSON object.
///
/// Returns the rendered text, or the `RenderError` produced by the renderer.
///
/// NOTE(review): presumably the template is the one loaded from
/// `assets/prompts/gemini_reminder.hbs`; its registration is not visible
/// here — confirm.
pub fn generate_gemini_reminder(&self, tools: Vec<String>) -> Result<String, RenderError> {
|
||||
self.handlebars
|
||||
// Lock the shared registry for the duration of the render call.
.lock()
|
||||
.render("gemini_reminder", &serde_json::json!({"tools": tools}))
|
||||
}
|
||||
|
||||
pub fn generate_inline_transformation_prompt(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
|
||||
Reference in New Issue
Block a user