Compare commits

...

12 Commits

Author SHA1 Message Date
Piotr Osiewicz
819d47b1d2 Fix deser errors 2024-08-06 15:26:21 +02:00
Piotr Osiewicz
3f487c11d8 Fix clippy violation (unused import) 2024-08-06 13:37:05 +02:00
Piotr Osiewicz
f72443dd8f Merge branch 'main' into tools-in-other-models 2024-08-06 13:17:40 +02:00
Piotr Osiewicz
e21d2cf995 WIP: Google AI 2024-08-05 21:45:55 +02:00
Piotr Osiewicz
c3a81ce44f Remove stray dbg 2024-08-05 17:16:50 +02:00
Piotr Osiewicz
8acee44091 Remove unused imports 2024-08-05 17:14:32 +02:00
Piotr Osiewicz
5bcba222d9 Add support for OpenAI tool calling in cloud provider 2024-08-05 17:07:49 +02:00
Piotr Osiewicz
bcc50b4fa8 Add Tool calling for ollama 2024-08-05 16:10:26 +02:00
Piotr Osiewicz
5f326752c5 Remove unused imports 2024-08-05 14:43:13 +02:00
Piotr Osiewicz
52a4bc942a assistant: Add tool calling for OpenAI 2024-08-05 14:39:36 +02:00
Piotr Osiewicz
6d82c3f4db Compilable state with tool calling (a wrong one, but still) 2024-08-05 12:58:12 +02:00
Piotr Osiewicz
63d9ed3e88 WIP: start implementing tool calling for openai models 2024-08-05 12:41:42 +02:00
9 changed files with 477 additions and 92 deletions

View File

@@ -27,6 +27,7 @@ pub async fn stream_generate_content(
match line {
Ok(line) => {
if let Some(line) = line.strip_prefix("data: ") {
dbg!(&line);
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
@@ -88,6 +89,37 @@ pub enum Task {
#[serde(rename = "batchEmbedContents")]
BatchEmbedContents,
}
/// An entry of the `tools` list carried by a generate-content request.
///
/// Variant names are serialized in snake_case, so `FunctionDeclarations`
/// is emitted under the `function_declarations` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Tool {
/// The function declarations carried by this tool entry.
FunctionDeclarations(Vec<FunctionDeclaration>),
}
/// Declaration of a single function that can be listed in a [`Tool`].
#[derive(Debug, Serialize, Deserialize)]
pub struct FunctionDeclaration {
/// Name of the function.
pub name: String,
/// Optional description; serialized as `null` when absent.
pub description: Option<String>,
/// Arbitrary JSON describing the function's parameters.
// NOTE(review): callers are presumably expected to pass a JSON schema here — confirm.
pub parameters: serde_json::Value,
}
/// Function-calling mode used by [`FunctionCallingConfig`].
///
/// No rename attribute is applied, so the variants are serialized exactly as
/// written (`"Auto"`, `"Any"`, `"None"`).
// NOTE(review): confirm that the remote API accepts this casing for its mode enum.
#[derive(Debug, Serialize, Deserialize)]
pub enum Mode {
Auto,
Any,
None,
}
/// Function-calling settings wrapped by [`ToolConfig`].
// NOTE(review): the field names are already snake_case, so `rename_all = "snake_case"`
// does not change the emitted keys (`mode`, `allowed_function_names`) — confirm
// that snake_case (rather than camelCase) keys are what the API expects.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct FunctionCallingConfig {
/// The function-calling mode to request.
pub mode: Mode,
/// Names of the functions covered by this configuration.
pub allowed_function_names: Vec<String>,
}
/// Tool configuration attached to a generate-content request.
///
/// Variant names are serialized in snake_case, so `FunctionCallingConfig`
/// is emitted under the `function_calling_config` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolConfig {
/// Settings that control function calling.
FunctionCallingConfig(FunctionCallingConfig),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -95,6 +127,9 @@ pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "String::is_empty")]
pub model: String,
pub contents: Vec<Content>,
#[serde(default)]
pub tools: Vec<Tool>,
pub tool_config: Option<ToolConfig>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
}
@@ -130,12 +165,23 @@ pub enum Role {
User,
Model,
}
/// A function call: the function's name plus its arguments as raw JSON.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCall {
/// Name of the function being called.
pub name: String,
/// Arguments of the call, kept as an untyped JSON value.
pub args: serde_json::Value,
}
/// Wrapper struct used as the payload of `Part::FunctionCall`.
///
/// Because of the camelCase renaming, the field is emitted under the
/// `functionCall` key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
/// The wrapped function call.
pub function_call: FunctionCall,
}
/// One element of a content's parts list.
///
/// The enum is untagged: no variant name appears in the JSON, and when
/// deserializing serde tries the variants in declaration order and keeps the
/// first one that matches. Reordering the variants can therefore change how
/// input is decoded.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
TextPart(TextPart),
InlineDataPart(InlineDataPart),
/// A function call, encoded as an object with a `functionCall` key.
FunctionCall(FunctionCallPart),
}
#[derive(Debug, Serialize, Deserialize)]

View File

@@ -17,7 +17,10 @@ pub(crate) use rate_limiter::*;
pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
use schemars::{
r#gen::{SchemaGenerator, SchemaSettings},
schema_for, JsonSchema,
};
use serde::de::DeserializeOwned;
use std::{future::Future, sync::Arc};
use ui::IconName;
@@ -76,6 +79,11 @@ pub trait LanguageModel: Send + Sync {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>>;
/// Specify how a schema should be generated for tool calling.
///
/// The default implementation returns the `schemars` JSON Schema draft-07
/// settings; implementors can override this method to return other settings.
fn schema_settings(&self) -> SchemaSettings {
SchemaSettings::draft07()
}
#[cfg(any(test, feature = "test-support"))]
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
unimplemented!()
@@ -88,7 +96,9 @@ impl dyn LanguageModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let mut settings = self.schema_settings();
let schema = schemars::gen::SchemaGenerator::new(settings).into_root_schema_for::<T>();
let schema_json = serde_json::to_value(&schema).unwrap();
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {

View File

@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, Context as _, Result};
use anyhow::{anyhow, bail, Context as _, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
@@ -634,8 +634,72 @@ impl LanguageModel for CloudLanguageModel {
})
.boxed()
}
CloudModel::OpenAi(_) => {
future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
let client = self.client.clone();
let mut function = open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = open_ai::ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()

View File

@@ -1,17 +1,21 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use google_ai::stream_generate_content;
use google_ai::{
stream_generate_content, FunctionCallPart, FunctionCallingConfig, FunctionDeclaration,
GenerateContentRequest, GenerateContentResponse, Mode, Part, Tool, ToolConfig,
};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
View, WhiteSpace,
};
use http_client::HttpClient;
use schemars::JsonSchema;
use schemars::{r#gen::SchemaSettings, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Indicator};
@@ -216,6 +220,34 @@ pub struct GoogleLanguageModel {
rate_limiter: RateLimiter,
}
impl GoogleLanguageModel {
    /// Starts a streaming generate-content request and returns the raw event stream.
    ///
    /// The API key is read from the model state and the API URL from the global
    /// Google settings. The request future is wrapped by `rate_limiter.stream`.
    ///
    /// # Errors
    ///
    /// Resolves to an error if the app state can no longer be read, if no API
    /// key is configured, or if `stream_generate_content` fails.
    fn stream_completion(
        &self,
        request: GenerateContentRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        // Snapshot everything we need from the app state up front so the
        // returned future does not borrow `cx`.
        let snapshot = cx.read_model(&self.state, |state, cx| {
            let google = &AllLanguageModelSettings::get_global(cx).google;
            (state.api_key.clone(), google.api_url.clone())
        });
        let (api_key, api_url) = match snapshot {
            Ok(values) => values,
            Err(_) => return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(),
        };

        let http_client = self.http_client.clone();
        let limited = self.rate_limiter.stream(async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let events =
                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request)
                    .await?;
            Ok(events)
        });

        async move {
            let events = limited.await?;
            Ok(events.boxed())
        }
        .boxed()
    }
}
impl LanguageModel for GoogleLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -276,34 +308,66 @@ impl LanguageModel for GoogleLanguageModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let request = request.into_google(self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
});
async move { Ok(future.await?.boxed()) }.boxed()
let completions = self.stream_completion(request, cx);
async move { Ok(google_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
}
/// Schema settings used when generating tool schemas for this model.
///
/// Starts from the `schemars` OpenAPI 3 settings, inlines all subschemas and
/// clears the meta-schema setting.
fn schema_settings(&self) -> schemars::r#gen::SchemaSettings {
    let mut settings = SchemaSettings::openapi3();
    settings.meta_schema = None;
    settings.inline_subschemas = true;
    settings
}
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
mut schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
let mut request = request.into_google(self.model.id().into());
if let Some(schema) = schema.as_object_mut() {
schema.remove("title");
}
let function = FunctionDeclaration {
name: tool_name.clone(),
description: Some(tool_description),
parameters: schema,
};
request.tool_config = Some(ToolConfig::FunctionCallingConfig(FunctionCallingConfig {
mode: Mode::Any,
allowed_function_names: vec![tool_name.clone()],
}));
request.tools = vec![Tool::FunctionDeclarations(vec![function])];
dbg!(&serde_json::to_string(&request).unwrap());
let response = self.stream_completion(request, cx);
self.rate_limiter
.run(async move {
let mut response = response.await?;
while let Some(part) = response.next().await {
let Some(part) = part.log_err() else {
continue;
};
let Some(candidates) = part.candidates else {
continue;
};
for choice in candidates {
for part in choice.content.parts {
if let Part::FunctionCall(FunctionCallPart { function_call }) = part {
if function_call.name == tool_name.as_str() {
return Ok(function_call.args);
}
}
}
}
}
bail!("tool not used");
})
.boxed()
}
}

View File

@@ -1,12 +1,14 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, OllamaToolCall,
};
use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
@@ -184,6 +186,7 @@ impl OllamaLanguageModel {
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
tool_calls: None,
},
Role::System => ChatMessage::System {
content: msg.content,
@@ -198,8 +201,25 @@ impl OllamaLanguageModel {
temperature: Some(request.temperature),
..Default::default()
}),
tools: vec![],
}
}
/// Sends `request` through `ollama::complete` and resolves to its response.
///
/// The API URL is read from the global Ollama settings before the future is
/// created, so the returned future does not borrow `cx`.
///
/// # Errors
///
/// Resolves to an error if the app context can no longer be updated or if
/// `ollama::complete` fails.
fn request_completion(
    &self,
    request: ChatRequest,
    cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<ChatResponseDelta>> {
    let http_client = self.http_client.clone();
    let lookup = cx.update(|cx| AllLanguageModelSettings::get_global(cx).ollama.api_url.clone());
    match lookup {
        Ok(api_url) => {
            async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
        }
        Err(_) => futures::future::ready(Err(anyhow!("App state dropped"))).boxed(),
    }
}
}
impl LanguageModel for OllamaLanguageModel {
@@ -269,7 +289,7 @@ impl LanguageModel for OllamaLanguageModel {
Ok(delta) => {
let content = match delta.message {
ChatMessage::User { content } => content,
ChatMessage::Assistant { content } => content,
ChatMessage::Assistant { content, .. } => content,
ChatMessage::System { content } => content,
};
Some(Ok(content))
@@ -286,13 +306,48 @@ impl LanguageModel for OllamaLanguageModel {
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
use ollama::{OllamaFunctionTool, OllamaTool};
let function = OllamaFunctionTool {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(schema),
};
let tools = vec![OllamaTool::Function { function }];
let request = self.to_ollama_request(request).with_tools(tools);
let response = self.request_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
let ChatMessage::Assistant {
tool_calls,
content,
} = response.message
else {
bail!("message does not have an assistant role");
};
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
for call in tool_calls {
let OllamaToolCall::Function(function) = call;
if function.name == tool_name {
return Ok(function.arguments);
}
}
} else if let Ok(args) = serde_json::from_str::<Value>(&content) {
// Parse content as arguments.
return Ok(args);
} else {
bail!("assistant message does not have any tool calls");
};
bail!("tool not used")
})
.boxed()
}
}

View File

@@ -1,4 +1,4 @@
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -7,11 +7,13 @@ use gpui::{
View, WhiteSpace,
};
use http_client::HttpClient;
use open_ai::stream_completion;
use open_ai::{
stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Indicator};
@@ -206,6 +208,41 @@ pub struct OpenAiLanguageModel {
request_limiter: RateLimiter,
}
impl OpenAiLanguageModel {
    /// Starts a streaming completion request and returns the raw event stream.
    ///
    /// The API key is read from the model state; the API URL and low-speed
    /// timeout come from the global OpenAI settings. The request future is
    /// wrapped by `request_limiter.stream`.
    ///
    /// # Errors
    ///
    /// Resolves to an error if the app state can no longer be read, if no API
    /// key is configured, or if `open_ai::stream_completion` fails.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        // Snapshot everything we need from the app state up front so the
        // returned future does not borrow `cx`.
        let snapshot = cx.read_model(&self.state, |state, cx| {
            let openai = &AllLanguageModelSettings::get_global(cx).openai;
            (
                state.api_key.clone(),
                openai.api_url.clone(),
                openai.low_speed_timeout,
            )
        });
        let (api_key, api_url, low_speed_timeout) = match snapshot {
            Ok(values) => values,
            Err(_) => return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(),
        };

        let http_client = self.http_client.clone();
        let limited = self.request_limiter.stream(async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            // Unqualified `stream_completion` is the imported `open_ai` function,
            // not this method.
            let events = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            )
            .await?;
            Ok(events)
        });

        async move {
            let events = limited.await?;
            Ok(events.boxed())
        }
        .boxed()
    }
}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -245,44 +282,68 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let request = request.into_open_ai(self.model.id().into());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
http_client.as_ref(),
&api_url,
&api_key,
request,
low_speed_timeout,
);
let response = request.await?;
Ok(open_ai::extract_text_from_events(response).boxed())
});
async move { Ok(future.await?.boxed()) }.boxed()
let completions = self.stream_completion(request, cx);
async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
}
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
let mut request = request.into_open_ai(self.model.id().into());
let mut function = FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(schema);
request.tools = vec![ToolDefinition::Function { function }];
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let mut response = response.await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None;
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
}

View File

@@ -69,6 +69,8 @@ impl LanguageModelRequest {
top_k: None,
}),
safety_settings: None,
tools: vec![],
tool_config: None,
}
}

View File

@@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -94,22 +95,63 @@ impl Model {
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant { content: String },
User { content: String },
System { content: String },
Assistant {
content: String,
tool_calls: Option<Vec<OllamaToolCall>>,
},
User {
content: String,
},
System {
content: String,
},
}
#[derive(Serialize)]
/// A tool call attached to an assistant message.
///
/// Variant names are lowercased, so `Function` is encoded as an object with a
/// single `function` key holding the call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
/// A call to a function tool.
Function(OllamaFunctionCall),
}
/// The function named by a tool call, together with its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionCall {
/// Name of the function being called.
pub name: String,
/// Arguments of the call, kept as an untyped JSON value.
pub arguments: Value,
}
/// Description of a function wrapped by `OllamaTool::Function`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
/// Name of the function.
pub name: String,
/// Optional description; serialized as `null` when absent.
pub description: Option<String>,
/// Optional JSON describing the function's parameters; serialized as `null` when absent.
// NOTE(review): presumably a JSON schema — confirm against the callers.
pub parameters: Option<Value>,
}
/// A tool entry of a [`ChatRequest`].
///
/// The enum is internally tagged by `type` with lowercased variant names, so
/// `Function` is encoded as `{"type": "function", "function": { ... }}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
/// A function tool.
Function { function: OllamaFunctionTool },
}
/// Body of a chat request.
///
/// Every field is always serialized; there are no skip attributes, so an
/// empty `tools` list is emitted as `[]` and a missing `options` as `null`.
#[derive(Serialize, Debug)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
// NOTE(review): presumably selects streamed vs. single-shot responses — confirm.
pub stream: bool,
pub keep_alive: KeepAlive,
pub options: Option<ChatOptions>,
/// Tools included with the request.
pub tools: Vec<OllamaTool>,
}
impl ChatRequest {
    /// Returns this request with its tool list replaced by `tools` and
    /// `stream` set to `false`; all other fields are carried over unchanged.
    pub fn with_tools(self, tools: Vec<OllamaTool>) -> Self {
        Self {
            stream: false,
            tools,
            ..self
        }
    }
}
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default)]
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
pub num_ctx: Option<usize>,
pub num_predict: Option<isize>,
@@ -118,7 +160,7 @@ pub struct ChatOptions {
pub top_p: Option<f32>,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
#[allow(unused)]
pub model: String,
@@ -162,6 +204,38 @@ pub struct ModelDetails {
pub quantization_level: String,
}
/// Sends `request` as a JSON `POST` to `{api_url}/api/chat` and decodes the
/// whole response body as a single [`ChatResponseDelta`].
///
/// # Errors
///
/// Returns an error if the request cannot be serialized or built, if sending
/// it or reading the body fails, if a successful response is not valid JSON
/// for [`ChatResponseDelta`], or if the server answers with a non-success
/// status (the status and the UTF-8 body are included in the message).
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    request: ChatRequest,
) -> Result<ChatResponseDelta> {
    let payload = serde_json::to_string(&request)?;
    let http_request = HttpRequest::builder()
        .method(Method::POST)
        .uri(format!("{api_url}/api/chat"))
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(payload))?;

    let mut response = client.send(http_request).await?;
    let status = response.status();

    // The body is needed on both paths: as the payload on success and as the
    // error detail on failure.
    let mut raw = Vec::new();
    response.body_mut().read_to_end(&mut raw).await?;

    if !status.is_success() {
        let detail = std::str::from_utf8(&raw)?;
        return Err(anyhow!("Failed to connect to API: {} {}", status, detail));
    }

    let message: ChatResponseDelta = serde_json::from_slice(&raw)?;
    Ok(message)
}
pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,

View File

@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter;
@@ -121,25 +121,34 @@ pub struct Request {
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Map<String, Value>>,
/// Value of the request's `tool_choice` field.
///
/// The unit variants are encoded as the plain strings `"auto"`, `"required"`
/// and `"none"`. `Other` is encoded as the wrapped [`ToolDefinition`] object
/// itself, with no enclosing variant name.
///
/// Only `Other` is marked `#[serde(untagged)]`. Marking the whole enum as
/// untagged would make serde encode every unit variant as JSON `null`, so
/// `Auto`, `Required` and `None` would be indistinguishable on the wire and
/// any `null` would decode as `Auto`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    // Untagged variants must come after all tagged ones; serde falls back to
    // this variant when the input is not one of the strings above.
    #[serde(untagged)]
    Other(ToolDefinition),
}
#[derive(Deserialize, Serialize, Debug)]
/// Definition of a tool, internally tagged by `type`.
///
/// Variant names are serialized in snake_case, so `Function` is encoded as
/// `{"type": "function", "function": { ... }}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
// NOTE(review): no reason is recorded for this `allow` — confirm it is still needed.
#[allow(dead_code)]
Function { function: FunctionDefinition },
}
/// The function described by `ToolDefinition::Function`.
///
/// The optional fields carry no skip attribute, so they are serialized as
/// `null` when they are `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
/// Name of the function.
pub name: String,
/// Optional description of the function.
pub description: Option<String>,
/// Optional JSON describing the function's parameters.
// NOTE(review): presumably a JSON schema — confirm against the callers.
pub parameters: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {