Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
caf46d8ddf | ||
|
|
808d262c3a | ||
|
|
944634e98c | ||
|
|
28be655495 | ||
|
|
935b16db1b | ||
|
|
f5f2f10e13 | ||
|
|
338dd62e0e | ||
|
|
f2cf745c3d | ||
|
|
2493e42564 | ||
|
|
53d14b32e2 |
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -3650,6 +3650,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"collections",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
@@ -3658,10 +3659,12 @@ dependencies = [
|
||||
"net",
|
||||
"parking_lot",
|
||||
"postage",
|
||||
"rand 0.9.2",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"sha2",
|
||||
"slotmap",
|
||||
"smol",
|
||||
"tempfile",
|
||||
|
||||
@@ -17,6 +17,7 @@ test-support = ["gpui/test-support"]
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
base64.workspace =true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
@@ -25,16 +26,18 @@ log.workspace = true
|
||||
net.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
schemars.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
sha2.workspace = true
|
||||
slotmap.workspace = true
|
||||
smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
terminal.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
terminal.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -6,6 +6,8 @@ pub mod test;
|
||||
pub mod transport;
|
||||
pub mod types;
|
||||
|
||||
pub use transport::UnauthorizedError;
|
||||
|
||||
use collections::HashMap;
|
||||
use http_client::HttpClient;
|
||||
use std::path::Path;
|
||||
|
||||
@@ -4,7 +4,7 @@ mod stdio_transport;
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::{error::Error, fmt, pin::Pin};
|
||||
|
||||
pub use http::*;
|
||||
pub use stdio_transport::*;
|
||||
@@ -15,3 +15,16 @@ pub trait Transport: Send + Sync {
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UnauthorizedError {
|
||||
pub www_authenticate_header: Option<String>,
|
||||
}
|
||||
|
||||
impl Error for UnauthorizedError {}
|
||||
|
||||
impl fmt::Display for UnauthorizedError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unauthorized")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
mod auth;
|
||||
mod www_authenticate;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use collections::HashMap;
|
||||
@@ -8,7 +11,10 @@ use parking_lot::Mutex as SyncMutex;
|
||||
use smol::channel;
|
||||
use std::{pin::Pin, sync::Arc};
|
||||
|
||||
use crate::transport::Transport;
|
||||
use crate::transport::{
|
||||
Transport, UnauthorizedError,
|
||||
http::{auth::OAuthClient, www_authenticate::WwwAuthenticate},
|
||||
};
|
||||
|
||||
// Constants from MCP spec
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
@@ -130,6 +136,16 @@ impl HttpTransport {
|
||||
// Accepted - notification acknowledged, no response needed
|
||||
log::debug!("Notification accepted");
|
||||
}
|
||||
status if status.as_u16() == 401 => {
|
||||
let www_authenticate_header = response
|
||||
.headers()
|
||||
.get("WWW-Authenticate")
|
||||
.and_then(|value| Some(value.to_str().ok()?.to_string()));
|
||||
|
||||
anyhow::bail!(UnauthorizedError {
|
||||
www_authenticate_header
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
let mut error_body = String::new();
|
||||
futures::AsyncReadExt::read_to_string(response.body_mut(), &mut error_body).await?;
|
||||
|
||||
675
crates/context_server/src/transport/http/auth.rs
Normal file
675
crates/context_server/src/transport/http/auth.rs
Normal file
@@ -0,0 +1,675 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
error::Error,
|
||||
fmt::{self, Display},
|
||||
ops::Deref,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use base64::Engine as _;
|
||||
use http_client::{AsyncBody, HttpClient, Request, Response, Uri};
|
||||
use rand::distr::Distribution;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned, ser};
|
||||
use serde_json::json;
|
||||
use sha2::{Digest, Sha256};
|
||||
use smol::io::AsyncReadExt;
|
||||
use url::Url;
|
||||
|
||||
pub const CALLBACK_URI: &str = "zed://mcp/auth/callback";
|
||||
|
||||
pub struct OAuthClient {
|
||||
registration: ClientRegistration,
|
||||
server: AuthorizationServer,
|
||||
scope: Option<String>,
|
||||
token: Option<Token>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
struct Token {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_at: Option<Instant>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
pub async fn init(
|
||||
server_endpoint: &str,
|
||||
www_authenticate: Option<&WwwAuthenticate<'_>>,
|
||||
http_client: &Arc<dyn HttpClient>,
|
||||
) -> Result<Self> {
|
||||
// https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-discovery
|
||||
// https://modelcontextprotocol.io/specification/draft/basic/authorization#protected-resource-metadata-discovery-requirements
|
||||
let resource =
|
||||
match www_authenticate.and_then(|challenge| challenge.resource_metadata.as_ref()) {
|
||||
Some(url) => ProtectedResource::fetch(url, http_client).await?,
|
||||
None => ProtectedResource::fetch_well_known(server_endpoint, http_client).await?,
|
||||
};
|
||||
|
||||
// https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
|
||||
let auth_server_url = resource
|
||||
.authorization_servers
|
||||
// todo! try others?
|
||||
.first()
|
||||
.context("Resource metadata specified 0 authorization servers")?;
|
||||
|
||||
let server = AuthorizationServer::fetch(auth_server_url, http_client).await?;
|
||||
|
||||
// https://modelcontextprotocol.io/specification/draft/basic/authorization#client-registration-approaches
|
||||
// TODO: Pre-registration from settings?
|
||||
let registration = if server.client_id_metadata_document_supported {
|
||||
todo!("host client id meta doc somewhere");
|
||||
} else if let Some(registration_endpoint) = server.registration_endpoint.as_ref() {
|
||||
Self::register(registration_endpoint, http_client).await?
|
||||
} else {
|
||||
todo!("allow user to specify custom client meta");
|
||||
};
|
||||
|
||||
// https://modelcontextprotocol.io/specification/draft/basic/authorization#scope-selection-strategy
|
||||
let scope = www_authenticate
|
||||
.and_then(|challenge| challenge.scope.as_ref().map(|s| s.to_string()))
|
||||
.or_else(|| {
|
||||
if resource.scopes_supported.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(resource.scopes_supported.join(" "))
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
registration,
|
||||
server,
|
||||
scope,
|
||||
token: None,
|
||||
http_client: http_client.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authorize_url(&self) -> Result<(Url, String)> {
|
||||
let auth_endpoint =
|
||||
self.server.authorization_endpoint.as_ref().context(
|
||||
"Authorization server metadata does not specify an authorization_endpoint",
|
||||
)?;
|
||||
|
||||
let code_verifier = generate_code_verifier();
|
||||
let code_challenge =
|
||||
base64::engine::general_purpose::URL_SAFE.encode(Sha256::digest(&code_verifier));
|
||||
|
||||
let mut authorize_url = Url::parse(&auth_endpoint.to_string())?;
|
||||
|
||||
authorize_url
|
||||
.query_pairs_mut()
|
||||
.append_pair("response_type", "code")
|
||||
.append_pair("client_id", &self.registration.client_id)
|
||||
.append_pair("redirect_uri", CALLBACK_URI)
|
||||
.append_pair("code_challenge", &code_challenge)
|
||||
.append_pair("code_challenge_method", "S256")
|
||||
.extend_pairs(self.scope.iter().map(|value| ("scope", value)));
|
||||
|
||||
anyhow::Ok((authorize_url, code_verifier))
|
||||
}
|
||||
|
||||
pub async fn exchange_token(&mut self, code: &str, code_verifier: &str) -> Result<()> {
|
||||
let token_endpoint = self
|
||||
.server
|
||||
.token_endpoint
|
||||
.as_ref()
|
||||
// todo! implicit?
|
||||
.context("Authorization server metadata does not specify a token_endpoint")?;
|
||||
|
||||
let form = url::form_urlencoded::Serializer::new(String::new())
|
||||
.append_pair("grant_type", "authorization_code")
|
||||
.append_pair("code", code)
|
||||
.append_pair("redirect_uri", CALLBACK_URI)
|
||||
.append_pair("client_id", &self.registration.client_id)
|
||||
.append_pair("code_verifier", code_verifier)
|
||||
.finish();
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(token_endpoint.clone())
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.header("Accept", "application/json")
|
||||
.body(AsyncBody::from(form))
|
||||
.context("Failed to build token exchange request")?;
|
||||
|
||||
let requested_at = Instant::now();
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
let token_response: TokenResponse = decode_response_json(&mut response).await?;
|
||||
|
||||
self.token = Some(Token {
|
||||
access_token: token_response.access_token,
|
||||
token_type: token_response.token_type,
|
||||
expires_at: token_response
|
||||
.expires_in
|
||||
.map(|expires_in| requested_at + Duration::from_secs(expires_in)),
|
||||
refresh_token: token_response.refresh_token,
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
async fn register(
|
||||
registration_endpoint: &AbsUri,
|
||||
http_client: &Arc<dyn HttpClient>,
|
||||
) -> Result<ClientRegistration> {
|
||||
let metadata = json!({
|
||||
"redirect_uris": [CALLBACK_URI],
|
||||
"token_endpoint_auth_method": "none",
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"client_name": "Zed",
|
||||
"client_uri": "https://zed.dev",
|
||||
"logo_uri": "https://zed.dev/_next/static/media/stable-app-logo.9b5f959f.png"
|
||||
});
|
||||
|
||||
post_json(®istration_endpoint.to_string(), metadata, http_client).await
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_code_verifier() -> String {
|
||||
const LENGTH: usize = 64;
|
||||
const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
|
||||
|
||||
let dist = rand::distr::slice::Choose::new(ALPHABET).unwrap();
|
||||
|
||||
let bytes: Vec<u8> = dist
|
||||
.sample_iter(rand::rng())
|
||||
.take(LENGTH)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
// SAFETY: All bytes come from ALPHABET which is ASCII
|
||||
unsafe { String::from_utf8_unchecked(bytes) }
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClientRegistration {
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
client_id_issued_at: Option<u64>,
|
||||
client_secret_expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: Option<u64>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
// Resource Metadata
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ProtectedResource {
|
||||
resource: String,
|
||||
|
||||
#[serde(default)]
|
||||
authorization_servers: Vec<AbsUri>,
|
||||
|
||||
#[serde(default)]
|
||||
scopes_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
bearer_methods_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
resource_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ProtectedResource {
|
||||
pub async fn fetch(url: &str, http_client: &Arc<dyn HttpClient>) -> Result<Self> {
|
||||
get_json(url, http_client)
|
||||
.await
|
||||
.context("Fetching resource metadata")
|
||||
}
|
||||
|
||||
pub async fn fetch_well_known(
|
||||
server_endpoint: &str,
|
||||
http_client: &Arc<dyn HttpClient>,
|
||||
) -> Result<Self> {
|
||||
let endpoint_uri = server_endpoint.parse::<Uri>()?.try_into()?;
|
||||
let well_known_uri = well_known_pre(&endpoint_uri, "oauth-protected-resource");
|
||||
|
||||
return Self::fetch(&well_known_uri, http_client)
|
||||
.await
|
||||
.context("From well-known URL");
|
||||
}
|
||||
}
|
||||
|
||||
// Server Metadata
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthorizationServer {
|
||||
issuer: String,
|
||||
|
||||
#[serde(default)]
|
||||
authorization_endpoint: Option<AbsUri>,
|
||||
|
||||
#[serde(default)]
|
||||
token_endpoint: Option<AbsUri>,
|
||||
|
||||
#[serde(default)]
|
||||
jwks_uri: Option<AbsUri>,
|
||||
|
||||
#[serde(default)]
|
||||
registration_endpoint: Option<AbsUri>,
|
||||
|
||||
#[serde(default)]
|
||||
scopes_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
response_types_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
grant_types_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
token_endpoint_auth_methods_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
code_challenge_methods_supported: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
client_id_metadata_document_supported: bool,
|
||||
}
|
||||
|
||||
impl AuthorizationServer {
|
||||
pub async fn fetch(
|
||||
issuer_uri: &AbsUri,
|
||||
http_client: &Arc<dyn HttpClient>,
|
||||
) -> Result<Self, AuthorizationServerMetadataDiscoveryError> {
|
||||
// We must attempt multiple well-known endpoints based on the issuer url
|
||||
//
|
||||
// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery
|
||||
let candidates: [fn(&AbsUri) -> Option<String>; _] = [
|
||||
// 1. OAuth 2.0 Authorization Server Metadata
|
||||
|base| well_known_pre(base, "oauth-authorization-server").into(),
|
||||
// 2. OpenID Connect Discovery 1.0 with path insertion
|
||||
|base| well_known_pre(base, "openid-configuration").into(),
|
||||
// 3. OpenID Connect Discovery 1.0 with path appening
|
||||
|base| {
|
||||
if base.path() != "/" {
|
||||
Some(well_known_post(base, "openid-configuration"))
|
||||
} else {
|
||||
// We already tried the root in the previous step
|
||||
None
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
let mut attempted_urls = Vec::new();
|
||||
|
||||
for build_url in candidates {
|
||||
let Some(url) = build_url(&issuer_uri) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
match get_json(&url, &http_client).await {
|
||||
Ok(meta) => return Ok(meta),
|
||||
Err(err) => {
|
||||
attempted_urls.push((url, err));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(AuthorizationServerMetadataDiscoveryError { attempted_urls })
|
||||
}
|
||||
}
|
||||
|
||||
fn well_known_pre(base_uri: &AbsUri, well_known_segment: &str) -> String {
|
||||
format!(
|
||||
"{}://{}/.well-known/{well_known_segment}{}",
|
||||
base_uri.scheme_str(),
|
||||
base_uri.authority(),
|
||||
base_uri.path().trim_end_matches('/')
|
||||
)
|
||||
}
|
||||
|
||||
fn well_known_post(base_uri: &AbsUri, well_known_segment: &str) -> String {
|
||||
let path = base_uri.path();
|
||||
let separator = if path.ends_with('/') { "" } else { "/" };
|
||||
format!(
|
||||
"{}://{}{}{separator}.well-known/{well_known_segment}",
|
||||
base_uri.scheme_str(),
|
||||
base_uri.authority(),
|
||||
path,
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthorizationServerMetadataDiscoveryError {
|
||||
attempted_urls: Vec<(String, anyhow::Error)>,
|
||||
}
|
||||
|
||||
impl Error for AuthorizationServerMetadataDiscoveryError {}
|
||||
|
||||
impl Display for AuthorizationServerMetadataDiscoveryError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"Failed to discover authorization server metadata. Attempted URLs:"
|
||||
)?;
|
||||
|
||||
for (url, err) in &self.attempted_urls {
|
||||
writeln!(f, "- {url}: {err}")?;
|
||||
}
|
||||
|
||||
fmt::Result::Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_json<Out: DeserializeOwned>(
|
||||
url: &str,
|
||||
http_client: &Arc<dyn HttpClient>,
|
||||
) -> Result<Out> {
|
||||
let mut response = http_client.get(url, AsyncBody::empty(), true).await?;
|
||||
decode_response_json(&mut response).await
|
||||
}
|
||||
|
||||
async fn post_json<In: Serialize, Out: DeserializeOwned>(
|
||||
url: &str,
|
||||
payload: In,
|
||||
http_client: &Arc<dyn HttpClient>,
|
||||
) -> Result<Out> {
|
||||
let mut response = http_client
|
||||
.post_json(url, serde_json::to_string(&payload)?.into())
|
||||
.await?;
|
||||
decode_response_json(&mut response).await
|
||||
}
|
||||
|
||||
async fn decode_response_json<T: DeserializeOwned>(
|
||||
response: &mut Response<AsyncBody>,
|
||||
) -> Result<T> {
|
||||
let mut content = Vec::new();
|
||||
response.body_mut().read_to_end(&mut content).await?;
|
||||
if response.status().is_success() {
|
||||
Ok(serde_json::from_slice(&content)?)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"Status: {}.\nBody: {}",
|
||||
response.status(),
|
||||
String::from_utf8_lossy(&content)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
use abs_uri::AbsUri;
|
||||
|
||||
use crate::transport::http::www_authenticate::WwwAuthenticate;
|
||||
mod abs_uri {
|
||||
use std::{
|
||||
error::Error,
|
||||
fmt::{self, Display},
|
||||
ops::Deref,
|
||||
};
|
||||
|
||||
use http_client::{Uri, http::uri::Authority};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AbsUri(Uri);
|
||||
|
||||
impl AbsUri {
|
||||
pub fn authority(&self) -> &Authority {
|
||||
self.0.authority().unwrap()
|
||||
}
|
||||
|
||||
pub fn scheme_str(&self) -> &str {
|
||||
self.0.scheme_str().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Uri> for AbsUri {
|
||||
fn into(self) -> Uri {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Uri> for AbsUri {
|
||||
type Error = AbsUriError;
|
||||
|
||||
fn try_from(uri: Uri) -> Result<Self, Self::Error> {
|
||||
if uri.scheme().is_none() {
|
||||
return Err(AbsUriError::MissingScheme);
|
||||
}
|
||||
if uri.authority().is_none() {
|
||||
return Err(AbsUriError::MissingAuthority);
|
||||
}
|
||||
Ok(Self(uri))
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for AbsUri {
|
||||
type Target = Uri;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for AbsUri {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
String::deserialize(deserializer)?
|
||||
.parse::<Uri>()
|
||||
.map_err(serde::de::Error::custom)?
|
||||
.try_into()
|
||||
.map_err(|e| serde::de::Error::custom(format!("{e:?}")))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AbsUriError {
|
||||
MissingScheme,
|
||||
MissingAuthority,
|
||||
}
|
||||
|
||||
impl Error for AbsUriError {}
|
||||
|
||||
impl Display for AbsUriError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
AbsUriError::MissingScheme => write!(f, "URI is not absolute: Missing scheme"),
|
||||
AbsUriError::MissingAuthority => {
|
||||
write!(f, "URI is not absolute: Missing authority")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use gpui::{TestAppContext, prelude::*};
|
||||
use http_client::{FakeHttpClient, Request, Response};
|
||||
|
||||
#[gpui::test]
|
||||
async fn fetch_server_metadata_chain(cx: &mut TestAppContext) {
|
||||
expect_fallback_chain(
|
||||
"https://auth.example.com/tenant/123",
|
||||
&[
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server/tenant/123",
|
||||
"https://auth.example.com/.well-known/openid-configuration/tenant/123",
|
||||
"https://auth.example.com/tenant/123/.well-known/openid-configuration",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
expect_fallback_chain(
|
||||
"https://auth.example.com/tenant/123/",
|
||||
&[
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server/tenant/123",
|
||||
"https://auth.example.com/.well-known/openid-configuration/tenant/123",
|
||||
"https://auth.example.com/tenant/123/.well-known/openid-configuration",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
expect_fallback_chain(
|
||||
"https://auth.example.com",
|
||||
&[
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
"https://auth.example.com/.well-known/openid-configuration",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn expect_fallback_chain(issuer_uri: &str, urls: &[&str], cx: &mut TestAppContext) {
|
||||
let issuer_uri: AbsUri = issuer_uri.parse::<Uri>().unwrap().try_into().unwrap();
|
||||
let (client, mut request_rx) = fake_client();
|
||||
|
||||
for i in 0..urls.len() {
|
||||
let issuer_uri = issuer_uri.clone();
|
||||
let client = client.clone();
|
||||
let fetch_task = cx.background_spawn(async move {
|
||||
AuthorizationServer::fetch(&issuer_uri, &client).await
|
||||
});
|
||||
|
||||
for request_url in &urls[..i] {
|
||||
let request = request_rx.next().await.unwrap();
|
||||
assert_eq!(request.uri, *request_url);
|
||||
respond(request, not_found());
|
||||
}
|
||||
|
||||
let request = request_rx.next().await.unwrap();
|
||||
assert_eq!(request.uri, *urls[i]);
|
||||
respond(
|
||||
request,
|
||||
Response::builder()
|
||||
.status(200)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(AsyncBody::from(valid_metadata_json(
|
||||
"https://auth.example.com",
|
||||
)))
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let metadata = fetch_task.await.expect("fetch should succeed");
|
||||
assert_eq!(metadata.issuer, "https://auth.example.com");
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn fetch_server_metadata_openid_root_stops_on_fail(cx: &mut TestAppContext) {
|
||||
let (client, mut requests) = fake_client();
|
||||
let http_client = client.clone();
|
||||
|
||||
let fetch_task = cx.background_spawn(async move {
|
||||
let issuer_uri: AbsUri = "https://auth.example.com"
|
||||
.parse::<Uri>()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
AuthorizationServer::fetch(&issuer_uri, &http_client).await
|
||||
});
|
||||
|
||||
let request = requests.next().await.expect("Expected first request");
|
||||
assert_eq!(
|
||||
request.uri,
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server"
|
||||
);
|
||||
respond(request, not_found());
|
||||
|
||||
let request = requests.next().await.expect("Expected second request");
|
||||
assert_eq!(
|
||||
request.uri,
|
||||
"https://auth.example.com/.well-known/openid-configuration"
|
||||
);
|
||||
respond(request, not_found());
|
||||
|
||||
// should not attempt well_known_post since it'd be the same as well_known_pre
|
||||
let error = fetch_task.await.expect_err("fetch should fail");
|
||||
assert_eq!(error.attempted_urls.len(), 2);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn fetch_server_metadata_all_fail(cx: &mut TestAppContext) {
|
||||
let (client, mut requests) = fake_client();
|
||||
let http_client = client.clone();
|
||||
|
||||
let fetch_task = cx.background_spawn(async move {
|
||||
let issuer_uri: AbsUri = "https://auth.example.com/tenant/123"
|
||||
.parse::<Uri>()
|
||||
.unwrap()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
AuthorizationServer::fetch(&issuer_uri, &http_client).await
|
||||
});
|
||||
|
||||
for _ in 0..3 {
|
||||
let request = requests.next().await.expect("Expected request");
|
||||
respond(request, not_found());
|
||||
}
|
||||
|
||||
let error = fetch_task.await.expect_err("fetch should fail");
|
||||
assert_eq!(error.attempted_urls.len(), 3);
|
||||
}
|
||||
|
||||
struct FakeRequest {
|
||||
uri: String,
|
||||
respond: oneshot::Sender<Response<AsyncBody>>,
|
||||
}
|
||||
|
||||
fn fake_client() -> (
|
||||
Arc<http_client::HttpClientWithUrl>,
|
||||
mpsc::UnboundedReceiver<FakeRequest>,
|
||||
) {
|
||||
let (request_sender, request_receiver) = mpsc::unbounded::<FakeRequest>();
|
||||
|
||||
let client = FakeHttpClient::create(move |req: Request<AsyncBody>| {
|
||||
let request_sender = request_sender.clone();
|
||||
async move {
|
||||
let (respond, response_receiver) = oneshot::channel();
|
||||
request_sender
|
||||
.unbounded_send(FakeRequest {
|
||||
uri: req.uri().to_string(),
|
||||
respond,
|
||||
})
|
||||
.expect("Test receiver dropped");
|
||||
|
||||
response_receiver
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Test dropped response sender"))
|
||||
}
|
||||
});
|
||||
|
||||
(client, request_receiver)
|
||||
}
|
||||
|
||||
fn not_found() -> Response<AsyncBody> {
|
||||
Response::builder()
|
||||
.status(404)
|
||||
.body(AsyncBody::from("Not found".to_string()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn valid_metadata_json(issuer: &str) -> String {
|
||||
serde_json::json!({
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": format!("{}/authorize", issuer),
|
||||
"token_endpoint": format!("{}/token", issuer),
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn respond(request: FakeRequest, response: Response<AsyncBody>) {
|
||||
request.respond.send(response).ok();
|
||||
}
|
||||
}
|
||||
255
crates/context_server/src/transport/http/www_authenticate.rs
Normal file
255
crates/context_server/src/transport/http/www_authenticate.rs
Normal file
@@ -0,0 +1,255 @@
|
||||
use std::borrow::Cow;
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
pub struct WwwAuthenticate<'a> {
|
||||
pub realm: Option<Cow<'a, str>>,
|
||||
pub scope: Option<Cow<'a, str>>,
|
||||
pub error: Option<Cow<'a, str>>,
|
||||
pub error_description: Option<Cow<'a, str>>,
|
||||
pub error_uri: Option<Cow<'a, str>>,
|
||||
pub resource_metadata: Option<Cow<'a, str>>,
|
||||
}
|
||||
|
||||
const BEARER_SCHEME: &str = "Bearer";
|
||||
|
||||
impl<'a> WwwAuthenticate<'a> {
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
// Header format (simplified):
|
||||
// Bearer realm="example", error="invalid_token", error_description="...", error_uri="..."
|
||||
let input = input.trim_ascii_start();
|
||||
|
||||
let (scheme, mut input) = input
|
||||
.trim_start()
|
||||
.split_once(|c: char| c.is_ascii_whitespace())
|
||||
.unwrap_or((input, ""));
|
||||
|
||||
// We only parse Bearer challenges as defined by RFC 6750 section 3.
|
||||
if !scheme.eq_ignore_ascii_case(BEARER_SCHEME) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut challenge = Self::default();
|
||||
|
||||
loop {
|
||||
input = input.trim_ascii_start();
|
||||
|
||||
if input.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Stop at a subsequent Bearer challenge in a combined header.
|
||||
if let Some(sub) = input.strip_prefix(BEARER_SCHEME)
|
||||
&& sub
|
||||
.chars()
|
||||
.next()
|
||||
.is_some_and(|character| character.is_ascii_whitespace())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
let (name, rest) = parse_token(input)?;
|
||||
let mut rest = rest.trim_ascii_start();
|
||||
|
||||
rest = rest.strip_prefix('=')?.trim_ascii_start();
|
||||
|
||||
let (value, rest) = parse_value(rest)?;
|
||||
input = rest;
|
||||
|
||||
match name {
|
||||
"realm" => challenge.realm = Some(value),
|
||||
"scope" => challenge.scope = Some(value),
|
||||
"error" => challenge.error = Some(value),
|
||||
"error_description" => challenge.error_description = Some(value),
|
||||
"error_uri" => {
|
||||
challenge.error_uri = Some(value);
|
||||
}
|
||||
"resource_metadata" => {
|
||||
challenge.resource_metadata = Some(value);
|
||||
}
|
||||
_ => {
|
||||
// Ignore extension auth-params.
|
||||
}
|
||||
}
|
||||
|
||||
input = input.trim_start();
|
||||
if let Some(after_comma) = input.strip_prefix(',') {
|
||||
input = after_comma;
|
||||
} else {
|
||||
// If there's no comma, we either reached the end or encountered something invalid.
|
||||
if !input.is_empty() {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(challenge)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_token(input: &str) -> Option<(&str, &str)> {
|
||||
let bytes = input.as_bytes();
|
||||
let mut end = 0;
|
||||
|
||||
while end < bytes.len() && is_tchar(bytes[end]) {
|
||||
end += 1;
|
||||
}
|
||||
|
||||
if end == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((&input[..end], &input[end..]))
|
||||
}
|
||||
|
||||
fn is_tchar(byte: u8) -> bool {
|
||||
matches!(
|
||||
byte,
|
||||
b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | b'.' | b'^' | b'_' | b'`' | b'|' | b'~'
|
||||
| b'0'..=b'9'
|
||||
| b'A'..=b'Z'
|
||||
| b'a'..=b'z'
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_value<'a>(input: &'a str) -> Option<(Cow<'a, str>, &'a str)> {
|
||||
if let Some(rest) = input.strip_prefix('"') {
|
||||
parse_quoted_value(rest)
|
||||
} else {
|
||||
let (token, rest) = parse_token(input)?;
|
||||
Some((Cow::Borrowed(token), rest))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_quoted_value<'a>(input: &'a str) -> Option<(Cow<'a, str>, &'a str)> {
|
||||
let mut output: Option<String> = None;
|
||||
let mut segment_start: usize = 0;
|
||||
|
||||
let mut iter = input.as_bytes().iter().enumerate();
|
||||
|
||||
while let Some((index, byte)) = iter.next() {
|
||||
match byte {
|
||||
b'"' => {
|
||||
let remainder = &input[index + 1..];
|
||||
|
||||
if let Some(mut output) = output {
|
||||
output.push_str(&input[segment_start..index]);
|
||||
return Some((Cow::Owned(output), remainder));
|
||||
}
|
||||
|
||||
return Some((Cow::Borrowed(&input[..index]), remainder));
|
||||
}
|
||||
b'\\' => {
|
||||
let (escaped_index, escaped_byte) = iter.next()?;
|
||||
|
||||
let output = output.get_or_insert_with(String::new);
|
||||
output.push_str(&input[segment_start..index]);
|
||||
output.push(*escaped_byte as char);
|
||||
|
||||
segment_start = escaped_index + 1;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parses_empty_bearer_challenge() {
|
||||
let challenge = WwwAuthenticate::parse("Bearer").expect("should parse Bearer scheme");
|
||||
assert_eq!(challenge, WwwAuthenticate::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_non_bearer_scheme() {
|
||||
assert!(WwwAuthenticate::parse("Basic realm=\"example\"").is_none());
|
||||
assert!(WwwAuthenticate::parse("Digest realm=\"example\"").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_known_parameters_with_quoted_strings_and_tokens() {
|
||||
let challenge = WwwAuthenticate::parse(
|
||||
"Bearer realm=\"example\", scope=\"read write\", error=invalid_token, error_description=\"The access token expired\"",
|
||||
)
|
||||
.expect("should parse");
|
||||
|
||||
assert_eq!(
|
||||
challenge,
|
||||
WwwAuthenticate {
|
||||
realm: Some(Cow::Borrowed("example")),
|
||||
scope: Some(Cow::Borrowed("read write")),
|
||||
error: Some(Cow::Borrowed("invalid_token")),
|
||||
error_description: Some(Cow::Borrowed("The access token expired")),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quoted_string_allows_commas_and_backslash_escapes() {
|
||||
let challenge = WwwAuthenticate::parse(
|
||||
"Bearer error_description=\"contains, comma and a quote: \\\" and a backslash: \\\\\"",
|
||||
)
|
||||
.expect("should parse");
|
||||
|
||||
assert_eq!(
|
||||
challenge,
|
||||
WwwAuthenticate {
|
||||
error_description: Some(Cow::Owned(
|
||||
"contains, comma and a quote: \" and a backslash: \\".to_string()
|
||||
)),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_unknown_extension_parameters() {
|
||||
let challenge =
|
||||
WwwAuthenticate::parse("Bearer realm=\"example\", foo=\"bar\"").expect("should parse");
|
||||
|
||||
assert_eq!(
|
||||
challenge,
|
||||
WwwAuthenticate {
|
||||
realm: Some(Cow::Borrowed("example")),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stops_at_subsequent_bearer_challenge_in_combined_header_value() {
|
||||
let challenge = WwwAuthenticate::parse(
|
||||
"Bearer realm=\"first\", error=\"invalid_token\", Bearer realm=\"second\"",
|
||||
)
|
||||
.expect("should parse");
|
||||
|
||||
assert_eq!(
|
||||
challenge,
|
||||
WwwAuthenticate {
|
||||
realm: Some(Cow::Borrowed("first")),
|
||||
error: Some(Cow::Borrowed("invalid_token")),
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_on_invalid_trailing_garbage() {
|
||||
assert!(WwwAuthenticate::parse("Bearer realm=\"example\" garbage").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_on_missing_equals() {
|
||||
assert!(WwwAuthenticate::parse("Bearer realm \"example\"").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_on_unterminated_quoted_string() {
|
||||
assert!(WwwAuthenticate::parse("Bearer realm=\"example").is_none());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user