Compare commits

...

10 Commits

Author SHA1 Message Date
Agus Zubiaga
caf46d8ddf wip 2025-12-31 15:16:19 -03:00
Agus Zubiaga
808d262c3a Add explicit Accept: application/json 2025-12-31 10:57:18 -03:00
Agus Zubiaga
944634e98c Checkpoint: authorize and exchange token 2025-12-31 10:55:36 -03:00
Agus Zubiaga
28be655495 Scope selection strategy 2025-12-30 17:22:44 -03:00
Agus Zubiaga
935b16db1b Checkpoint: Sketching flow 2025-12-30 17:05:11 -03:00
Agus Zubiaga
f5f2f10e13 Test server metadata fallback 2025-12-28 16:35:32 -03:00
Agus Zubiaga
338dd62e0e Fetch authorization server metadata 2025-12-28 12:50:29 -03:00
Agus Zubiaga
f2cf745c3d Fetch ProtectedResourceMetadata 2025-12-27 15:12:56 -03:00
Agus Zubiaga
2493e42564 Add resource_metadata field 2025-12-27 13:59:29 -03:00
Agus Zubiaga
53d14b32e2 Parse WWW-Authenticate 2025-12-27 13:49:54 -03:00
7 changed files with 971 additions and 4 deletions

3
Cargo.lock generated
View File

@@ -3650,6 +3650,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"base64 0.22.1",
"collections",
"futures 0.3.31",
"gpui",
@@ -3658,10 +3659,12 @@ dependencies = [
"net",
"parking_lot",
"postage",
"rand 0.9.2",
"schemars",
"serde",
"serde_json",
"settings",
"sha2",
"slotmap",
"smol",
"tempfile",

View File

@@ -17,6 +17,7 @@ test-support = ["gpui/test-support"]
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
base64.workspace =true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -25,16 +26,18 @@ log.workspace = true
net.workspace = true
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
schemars.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
sha2.workspace = true
slotmap.workspace = true
smol.workspace = true
tempfile.workspace = true
terminal.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
terminal.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -6,6 +6,8 @@ pub mod test;
pub mod transport;
pub mod types;
pub use transport::UnauthorizedError;
use collections::HashMap;
use http_client::HttpClient;
use std::path::Path;

View File

@@ -4,7 +4,7 @@ mod stdio_transport;
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use std::{error::Error, fmt, pin::Pin};
pub use http::*;
pub use stdio_transport::*;
@@ -15,3 +15,16 @@ pub trait Transport: Send + Sync {
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>>;
}
#[derive(Debug)]
pub struct UnauthorizedError {
pub www_authenticate_header: Option<String>,
}
impl Error for UnauthorizedError {}
impl fmt::Display for UnauthorizedError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthorized")
}
}

View File

@@ -1,3 +1,6 @@
mod auth;
mod www_authenticate;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use collections::HashMap;
@@ -8,7 +11,10 @@ use parking_lot::Mutex as SyncMutex;
use smol::channel;
use std::{pin::Pin, sync::Arc};
use crate::transport::Transport;
use crate::transport::{
Transport, UnauthorizedError,
http::{auth::OAuthClient, www_authenticate::WwwAuthenticate},
};
// Constants from MCP spec
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
@@ -130,6 +136,16 @@ impl HttpTransport {
// Accepted - notification acknowledged, no response needed
log::debug!("Notification accepted");
}
status if status.as_u16() == 401 => {
let www_authenticate_header = response
.headers()
.get("WWW-Authenticate")
.and_then(|value| Some(value.to_str().ok()?.to_string()));
anyhow::bail!(UnauthorizedError {
www_authenticate_header
})
}
_ => {
let mut error_body = String::new();
futures::AsyncReadExt::read_to_string(response.body_mut(), &mut error_body).await?;

View File

@@ -0,0 +1,675 @@
use std::{
borrow::Cow,
error::Error,
fmt::{self, Display},
ops::Deref,
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{Context as _, Result};
use base64::Engine as _;
use http_client::{AsyncBody, HttpClient, Request, Response, Uri};
use rand::distr::Distribution;
use serde::{Deserialize, Serialize, de::DeserializeOwned, ser};
use serde_json::json;
use sha2::{Digest, Sha256};
use smol::io::AsyncReadExt;
use url::Url;
pub const CALLBACK_URI: &str = "zed://mcp/auth/callback";
pub struct OAuthClient {
registration: ClientRegistration,
server: AuthorizationServer,
scope: Option<String>,
token: Option<Token>,
http_client: Arc<dyn HttpClient>,
}
struct Token {
access_token: String,
token_type: String,
expires_at: Option<Instant>,
refresh_token: Option<String>,
}
impl OAuthClient {
pub async fn init(
server_endpoint: &str,
www_authenticate: Option<&WwwAuthenticate<'_>>,
http_client: &Arc<dyn HttpClient>,
) -> Result<Self> {
// https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-discovery
// https://modelcontextprotocol.io/specification/draft/basic/authorization#protected-resource-metadata-discovery-requirements
let resource =
match www_authenticate.and_then(|challenge| challenge.resource_metadata.as_ref()) {
Some(url) => ProtectedResource::fetch(url, http_client).await?,
None => ProtectedResource::fetch_well_known(server_endpoint, http_client).await?,
};
// https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
let auth_server_url = resource
.authorization_servers
// todo! try others?
.first()
.context("Resource metadata specified 0 authorization servers")?;
let server = AuthorizationServer::fetch(auth_server_url, http_client).await?;
// https://modelcontextprotocol.io/specification/draft/basic/authorization#client-registration-approaches
// TODO: Pre-registration from settings?
let registration = if server.client_id_metadata_document_supported {
todo!("host client id meta doc somewhere");
} else if let Some(registration_endpoint) = server.registration_endpoint.as_ref() {
Self::register(registration_endpoint, http_client).await?
} else {
todo!("allow user to specify custom client meta");
};
// https://modelcontextprotocol.io/specification/draft/basic/authorization#scope-selection-strategy
let scope = www_authenticate
.and_then(|challenge| challenge.scope.as_ref().map(|s| s.to_string()))
.or_else(|| {
if resource.scopes_supported.is_empty() {
None
} else {
Some(resource.scopes_supported.join(" "))
}
});
Ok(Self {
registration,
server,
scope,
token: None,
http_client: http_client.clone(),
})
}
pub fn authorize_url(&self) -> Result<(Url, String)> {
let auth_endpoint =
self.server.authorization_endpoint.as_ref().context(
"Authorization server metadata does not specify an authorization_endpoint",
)?;
let code_verifier = generate_code_verifier();
let code_challenge =
base64::engine::general_purpose::URL_SAFE.encode(Sha256::digest(&code_verifier));
let mut authorize_url = Url::parse(&auth_endpoint.to_string())?;
authorize_url
.query_pairs_mut()
.append_pair("response_type", "code")
.append_pair("client_id", &self.registration.client_id)
.append_pair("redirect_uri", CALLBACK_URI)
.append_pair("code_challenge", &code_challenge)
.append_pair("code_challenge_method", "S256")
.extend_pairs(self.scope.iter().map(|value| ("scope", value)));
anyhow::Ok((authorize_url, code_verifier))
}
pub async fn exchange_token(&mut self, code: &str, code_verifier: &str) -> Result<()> {
let token_endpoint = self
.server
.token_endpoint
.as_ref()
// todo! implicit?
.context("Authorization server metadata does not specify a token_endpoint")?;
let form = url::form_urlencoded::Serializer::new(String::new())
.append_pair("grant_type", "authorization_code")
.append_pair("code", code)
.append_pair("redirect_uri", CALLBACK_URI)
.append_pair("client_id", &self.registration.client_id)
.append_pair("code_verifier", code_verifier)
.finish();
let request = Request::builder()
.uri(token_endpoint.clone())
.header("Content-Type", "application/x-www-form-urlencoded")
.header("Accept", "application/json")
.body(AsyncBody::from(form))
.context("Failed to build token exchange request")?;
let requested_at = Instant::now();
let mut response = self.http_client.send(request).await?;
let token_response: TokenResponse = decode_response_json(&mut response).await?;
self.token = Some(Token {
access_token: token_response.access_token,
token_type: token_response.token_type,
expires_at: token_response
.expires_in
.map(|expires_in| requested_at + Duration::from_secs(expires_in)),
refresh_token: token_response.refresh_token,
});
anyhow::Ok(())
}
async fn register(
registration_endpoint: &AbsUri,
http_client: &Arc<dyn HttpClient>,
) -> Result<ClientRegistration> {
let metadata = json!({
"redirect_uris": [CALLBACK_URI],
"token_endpoint_auth_method": "none",
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"client_name": "Zed",
"client_uri": "https://zed.dev",
"logo_uri": "https://zed.dev/_next/static/media/stable-app-logo.9b5f959f.png"
});
post_json(&registration_endpoint.to_string(), metadata, http_client).await
}
}
fn generate_code_verifier() -> String {
const LENGTH: usize = 64;
const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
let dist = rand::distr::slice::Choose::new(ALPHABET).unwrap();
let bytes: Vec<u8> = dist
.sample_iter(rand::rng())
.take(LENGTH)
.copied()
.collect();
// SAFETY: All bytes come from ALPHABET which is ASCII
unsafe { String::from_utf8_unchecked(bytes) }
}
#[derive(Deserialize)]
struct ClientRegistration {
client_id: String,
client_secret: Option<String>,
client_id_issued_at: Option<u64>,
client_secret_expires_at: Option<u64>,
}
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
token_type: String,
expires_in: Option<u64>,
refresh_token: Option<String>,
}
// Resource Metadata
#[derive(Deserialize)]
pub struct ProtectedResource {
resource: String,
#[serde(default)]
authorization_servers: Vec<AbsUri>,
#[serde(default)]
scopes_supported: Vec<String>,
#[serde(default)]
bearer_methods_supported: Vec<String>,
#[serde(default)]
resource_name: Option<String>,
}
impl ProtectedResource {
pub async fn fetch(url: &str, http_client: &Arc<dyn HttpClient>) -> Result<Self> {
get_json(url, http_client)
.await
.context("Fetching resource metadata")
}
pub async fn fetch_well_known(
server_endpoint: &str,
http_client: &Arc<dyn HttpClient>,
) -> Result<Self> {
let endpoint_uri = server_endpoint.parse::<Uri>()?.try_into()?;
let well_known_uri = well_known_pre(&endpoint_uri, "oauth-protected-resource");
return Self::fetch(&well_known_uri, http_client)
.await
.context("From well-known URL");
}
}
// Server Metadata
#[derive(Debug, Deserialize)]
pub struct AuthorizationServer {
issuer: String,
#[serde(default)]
authorization_endpoint: Option<AbsUri>,
#[serde(default)]
token_endpoint: Option<AbsUri>,
#[serde(default)]
jwks_uri: Option<AbsUri>,
#[serde(default)]
registration_endpoint: Option<AbsUri>,
#[serde(default)]
scopes_supported: Vec<String>,
#[serde(default)]
response_types_supported: Vec<String>,
#[serde(default)]
grant_types_supported: Vec<String>,
#[serde(default)]
token_endpoint_auth_methods_supported: Vec<String>,
#[serde(default)]
code_challenge_methods_supported: Vec<String>,
#[serde(default)]
client_id_metadata_document_supported: bool,
}
impl AuthorizationServer {
pub async fn fetch(
issuer_uri: &AbsUri,
http_client: &Arc<dyn HttpClient>,
) -> Result<Self, AuthorizationServerMetadataDiscoveryError> {
// We must attempt multiple well-known endpoints based on the issuer url
//
// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery
let candidates: [fn(&AbsUri) -> Option<String>; _] = [
// 1. OAuth 2.0 Authorization Server Metadata
|base| well_known_pre(base, "oauth-authorization-server").into(),
// 2. OpenID Connect Discovery 1.0 with path insertion
|base| well_known_pre(base, "openid-configuration").into(),
// 3. OpenID Connect Discovery 1.0 with path appening
|base| {
if base.path() != "/" {
Some(well_known_post(base, "openid-configuration"))
} else {
// We already tried the root in the previous step
None
}
},
];
let mut attempted_urls = Vec::new();
for build_url in candidates {
let Some(url) = build_url(&issuer_uri) else {
continue;
};
match get_json(&url, &http_client).await {
Ok(meta) => return Ok(meta),
Err(err) => {
attempted_urls.push((url, err));
}
}
}
Err(AuthorizationServerMetadataDiscoveryError { attempted_urls })
}
}
fn well_known_pre(base_uri: &AbsUri, well_known_segment: &str) -> String {
format!(
"{}://{}/.well-known/{well_known_segment}{}",
base_uri.scheme_str(),
base_uri.authority(),
base_uri.path().trim_end_matches('/')
)
}
fn well_known_post(base_uri: &AbsUri, well_known_segment: &str) -> String {
let path = base_uri.path();
let separator = if path.ends_with('/') { "" } else { "/" };
format!(
"{}://{}{}{separator}.well-known/{well_known_segment}",
base_uri.scheme_str(),
base_uri.authority(),
path,
)
}
#[derive(Debug)]
pub struct AuthorizationServerMetadataDiscoveryError {
attempted_urls: Vec<(String, anyhow::Error)>,
}
impl Error for AuthorizationServerMetadataDiscoveryError {}
impl Display for AuthorizationServerMetadataDiscoveryError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"Failed to discover authorization server metadata. Attempted URLs:"
)?;
for (url, err) in &self.attempted_urls {
writeln!(f, "- {url}: {err}")?;
}
fmt::Result::Ok(())
}
}
async fn get_json<Out: DeserializeOwned>(
url: &str,
http_client: &Arc<dyn HttpClient>,
) -> Result<Out> {
let mut response = http_client.get(url, AsyncBody::empty(), true).await?;
decode_response_json(&mut response).await
}
async fn post_json<In: Serialize, Out: DeserializeOwned>(
url: &str,
payload: In,
http_client: &Arc<dyn HttpClient>,
) -> Result<Out> {
let mut response = http_client
.post_json(url, serde_json::to_string(&payload)?.into())
.await?;
decode_response_json(&mut response).await
}
async fn decode_response_json<T: DeserializeOwned>(
response: &mut Response<AsyncBody>,
) -> Result<T> {
let mut content = Vec::new();
response.body_mut().read_to_end(&mut content).await?;
if response.status().is_success() {
Ok(serde_json::from_slice(&content)?)
} else {
anyhow::bail!(
"Status: {}.\nBody: {}",
response.status(),
String::from_utf8_lossy(&content)
);
}
}
use abs_uri::AbsUri;
use crate::transport::http::www_authenticate::WwwAuthenticate;
mod abs_uri {
use std::{
error::Error,
fmt::{self, Display},
ops::Deref,
};
use http_client::{Uri, http::uri::Authority};
use serde::Deserialize;
#[derive(Debug, Clone)]
pub struct AbsUri(Uri);
impl AbsUri {
pub fn authority(&self) -> &Authority {
self.0.authority().unwrap()
}
pub fn scheme_str(&self) -> &str {
self.0.scheme_str().unwrap()
}
}
impl Into<Uri> for AbsUri {
fn into(self) -> Uri {
self.0
}
}
impl TryFrom<Uri> for AbsUri {
type Error = AbsUriError;
fn try_from(uri: Uri) -> Result<Self, Self::Error> {
if uri.scheme().is_none() {
return Err(AbsUriError::MissingScheme);
}
if uri.authority().is_none() {
return Err(AbsUriError::MissingAuthority);
}
Ok(Self(uri))
}
}
impl Deref for AbsUri {
type Target = Uri;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'de> Deserialize<'de> for AbsUri {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
String::deserialize(deserializer)?
.parse::<Uri>()
.map_err(serde::de::Error::custom)?
.try_into()
.map_err(|e| serde::de::Error::custom(format!("{e:?}")))
}
}
#[derive(Debug)]
pub enum AbsUriError {
MissingScheme,
MissingAuthority,
}
impl Error for AbsUriError {}
impl Display for AbsUriError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AbsUriError::MissingScheme => write!(f, "URI is not absolute: Missing scheme"),
AbsUriError::MissingAuthority => {
write!(f, "URI is not absolute: Missing authority")
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use futures::StreamExt;
use futures::channel::{mpsc, oneshot};
use gpui::{TestAppContext, prelude::*};
use http_client::{FakeHttpClient, Request, Response};
#[gpui::test]
async fn fetch_server_metadata_chain(cx: &mut TestAppContext) {
expect_fallback_chain(
"https://auth.example.com/tenant/123",
&[
"https://auth.example.com/.well-known/oauth-authorization-server/tenant/123",
"https://auth.example.com/.well-known/openid-configuration/tenant/123",
"https://auth.example.com/tenant/123/.well-known/openid-configuration",
],
cx,
)
.await;
expect_fallback_chain(
"https://auth.example.com/tenant/123/",
&[
"https://auth.example.com/.well-known/oauth-authorization-server/tenant/123",
"https://auth.example.com/.well-known/openid-configuration/tenant/123",
"https://auth.example.com/tenant/123/.well-known/openid-configuration",
],
cx,
)
.await;
expect_fallback_chain(
"https://auth.example.com",
&[
"https://auth.example.com/.well-known/oauth-authorization-server",
"https://auth.example.com/.well-known/openid-configuration",
],
cx,
)
.await;
}
async fn expect_fallback_chain(issuer_uri: &str, urls: &[&str], cx: &mut TestAppContext) {
let issuer_uri: AbsUri = issuer_uri.parse::<Uri>().unwrap().try_into().unwrap();
let (client, mut request_rx) = fake_client();
for i in 0..urls.len() {
let issuer_uri = issuer_uri.clone();
let client = client.clone();
let fetch_task = cx.background_spawn(async move {
AuthorizationServer::fetch(&issuer_uri, &client).await
});
for request_url in &urls[..i] {
let request = request_rx.next().await.unwrap();
assert_eq!(request.uri, *request_url);
respond(request, not_found());
}
let request = request_rx.next().await.unwrap();
assert_eq!(request.uri, *urls[i]);
respond(
request,
Response::builder()
.status(200)
.header("Content-Type", "application/json")
.body(AsyncBody::from(valid_metadata_json(
"https://auth.example.com",
)))
.unwrap(),
);
let metadata = fetch_task.await.expect("fetch should succeed");
assert_eq!(metadata.issuer, "https://auth.example.com");
}
}
#[gpui::test]
async fn fetch_server_metadata_openid_root_stops_on_fail(cx: &mut TestAppContext) {
let (client, mut requests) = fake_client();
let http_client = client.clone();
let fetch_task = cx.background_spawn(async move {
let issuer_uri: AbsUri = "https://auth.example.com"
.parse::<Uri>()
.unwrap()
.try_into()
.unwrap();
AuthorizationServer::fetch(&issuer_uri, &http_client).await
});
let request = requests.next().await.expect("Expected first request");
assert_eq!(
request.uri,
"https://auth.example.com/.well-known/oauth-authorization-server"
);
respond(request, not_found());
let request = requests.next().await.expect("Expected second request");
assert_eq!(
request.uri,
"https://auth.example.com/.well-known/openid-configuration"
);
respond(request, not_found());
// should not attempt well_known_post since it'd be the same as well_known_pre
let error = fetch_task.await.expect_err("fetch should fail");
assert_eq!(error.attempted_urls.len(), 2);
}
#[gpui::test]
async fn fetch_server_metadata_all_fail(cx: &mut TestAppContext) {
let (client, mut requests) = fake_client();
let http_client = client.clone();
let fetch_task = cx.background_spawn(async move {
let issuer_uri: AbsUri = "https://auth.example.com/tenant/123"
.parse::<Uri>()
.unwrap()
.try_into()
.unwrap();
AuthorizationServer::fetch(&issuer_uri, &http_client).await
});
for _ in 0..3 {
let request = requests.next().await.expect("Expected request");
respond(request, not_found());
}
let error = fetch_task.await.expect_err("fetch should fail");
assert_eq!(error.attempted_urls.len(), 3);
}
struct FakeRequest {
uri: String,
respond: oneshot::Sender<Response<AsyncBody>>,
}
fn fake_client() -> (
Arc<http_client::HttpClientWithUrl>,
mpsc::UnboundedReceiver<FakeRequest>,
) {
let (request_sender, request_receiver) = mpsc::unbounded::<FakeRequest>();
let client = FakeHttpClient::create(move |req: Request<AsyncBody>| {
let request_sender = request_sender.clone();
async move {
let (respond, response_receiver) = oneshot::channel();
request_sender
.unbounded_send(FakeRequest {
uri: req.uri().to_string(),
respond,
})
.expect("Test receiver dropped");
response_receiver
.await
.map_err(|_| anyhow::anyhow!("Test dropped response sender"))
}
});
(client, request_receiver)
}
fn not_found() -> Response<AsyncBody> {
Response::builder()
.status(404)
.body(AsyncBody::from("Not found".to_string()))
.unwrap()
}
fn valid_metadata_json(issuer: &str) -> String {
serde_json::json!({
"issuer": issuer,
"authorization_endpoint": format!("{}/authorize", issuer),
"token_endpoint": format!("{}/token", issuer),
})
.to_string()
}
fn respond(request: FakeRequest, response: Response<AsyncBody>) {
request.respond.send(response).ok();
}
}

View File

@@ -0,0 +1,255 @@
use std::borrow::Cow;
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate<'a> {
pub realm: Option<Cow<'a, str>>,
pub scope: Option<Cow<'a, str>>,
pub error: Option<Cow<'a, str>>,
pub error_description: Option<Cow<'a, str>>,
pub error_uri: Option<Cow<'a, str>>,
pub resource_metadata: Option<Cow<'a, str>>,
}
const BEARER_SCHEME: &str = "Bearer";
impl<'a> WwwAuthenticate<'a> {
pub fn parse(input: &'a str) -> Option<Self> {
// Header format (simplified):
// Bearer realm="example", error="invalid_token", error_description="...", error_uri="..."
let input = input.trim_ascii_start();
let (scheme, mut input) = input
.trim_start()
.split_once(|c: char| c.is_ascii_whitespace())
.unwrap_or((input, ""));
// We only parse Bearer challenges as defined by RFC 6750 section 3.
if !scheme.eq_ignore_ascii_case(BEARER_SCHEME) {
return None;
}
let mut challenge = Self::default();
loop {
input = input.trim_ascii_start();
if input.is_empty() {
break;
}
// Stop at a subsequent Bearer challenge in a combined header.
if let Some(sub) = input.strip_prefix(BEARER_SCHEME)
&& sub
.chars()
.next()
.is_some_and(|character| character.is_ascii_whitespace())
{
break;
}
let (name, rest) = parse_token(input)?;
let mut rest = rest.trim_ascii_start();
rest = rest.strip_prefix('=')?.trim_ascii_start();
let (value, rest) = parse_value(rest)?;
input = rest;
match name {
"realm" => challenge.realm = Some(value),
"scope" => challenge.scope = Some(value),
"error" => challenge.error = Some(value),
"error_description" => challenge.error_description = Some(value),
"error_uri" => {
challenge.error_uri = Some(value);
}
"resource_metadata" => {
challenge.resource_metadata = Some(value);
}
_ => {
// Ignore extension auth-params.
}
}
input = input.trim_start();
if let Some(after_comma) = input.strip_prefix(',') {
input = after_comma;
} else {
// If there's no comma, we either reached the end or encountered something invalid.
if !input.is_empty() {
return None;
}
}
}
Some(challenge)
}
}
fn parse_token(input: &str) -> Option<(&str, &str)> {
let bytes = input.as_bytes();
let mut end = 0;
while end < bytes.len() && is_tchar(bytes[end]) {
end += 1;
}
if end == 0 {
return None;
}
Some((&input[..end], &input[end..]))
}
fn is_tchar(byte: u8) -> bool {
matches!(
byte,
b'!' | b'#' | b'$' | b'%' | b'&' | b'\'' | b'*' | b'+' | b'-' | b'.' | b'^' | b'_' | b'`' | b'|' | b'~'
| b'0'..=b'9'
| b'A'..=b'Z'
| b'a'..=b'z'
)
}
fn parse_value<'a>(input: &'a str) -> Option<(Cow<'a, str>, &'a str)> {
if let Some(rest) = input.strip_prefix('"') {
parse_quoted_value(rest)
} else {
let (token, rest) = parse_token(input)?;
Some((Cow::Borrowed(token), rest))
}
}
fn parse_quoted_value<'a>(input: &'a str) -> Option<(Cow<'a, str>, &'a str)> {
let mut output: Option<String> = None;
let mut segment_start: usize = 0;
let mut iter = input.as_bytes().iter().enumerate();
while let Some((index, byte)) = iter.next() {
match byte {
b'"' => {
let remainder = &input[index + 1..];
if let Some(mut output) = output {
output.push_str(&input[segment_start..index]);
return Some((Cow::Owned(output), remainder));
}
return Some((Cow::Borrowed(&input[..index]), remainder));
}
b'\\' => {
let (escaped_index, escaped_byte) = iter.next()?;
let output = output.get_or_insert_with(String::new);
output.push_str(&input[segment_start..index]);
output.push(*escaped_byte as char);
segment_start = escaped_index + 1;
}
_ => {}
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parses_empty_bearer_challenge() {
let challenge = WwwAuthenticate::parse("Bearer").expect("should parse Bearer scheme");
assert_eq!(challenge, WwwAuthenticate::default());
}
#[test]
fn rejects_non_bearer_scheme() {
assert!(WwwAuthenticate::parse("Basic realm=\"example\"").is_none());
assert!(WwwAuthenticate::parse("Digest realm=\"example\"").is_none());
}
#[test]
fn parses_known_parameters_with_quoted_strings_and_tokens() {
let challenge = WwwAuthenticate::parse(
"Bearer realm=\"example\", scope=\"read write\", error=invalid_token, error_description=\"The access token expired\"",
)
.expect("should parse");
assert_eq!(
challenge,
WwwAuthenticate {
realm: Some(Cow::Borrowed("example")),
scope: Some(Cow::Borrowed("read write")),
error: Some(Cow::Borrowed("invalid_token")),
error_description: Some(Cow::Borrowed("The access token expired")),
..Default::default()
}
);
}
#[test]
fn quoted_string_allows_commas_and_backslash_escapes() {
let challenge = WwwAuthenticate::parse(
"Bearer error_description=\"contains, comma and a quote: \\\" and a backslash: \\\\\"",
)
.expect("should parse");
assert_eq!(
challenge,
WwwAuthenticate {
error_description: Some(Cow::Owned(
"contains, comma and a quote: \" and a backslash: \\".to_string()
)),
..Default::default()
}
);
}
#[test]
fn ignores_unknown_extension_parameters() {
let challenge =
WwwAuthenticate::parse("Bearer realm=\"example\", foo=\"bar\"").expect("should parse");
assert_eq!(
challenge,
WwwAuthenticate {
realm: Some(Cow::Borrowed("example")),
..Default::default()
}
);
}
#[test]
fn stops_at_subsequent_bearer_challenge_in_combined_header_value() {
let challenge = WwwAuthenticate::parse(
"Bearer realm=\"first\", error=\"invalid_token\", Bearer realm=\"second\"",
)
.expect("should parse");
assert_eq!(
challenge,
WwwAuthenticate {
realm: Some(Cow::Borrowed("first")),
error: Some(Cow::Borrowed("invalid_token")),
..Default::default()
}
);
}
#[test]
fn returns_none_on_invalid_trailing_garbage() {
assert!(WwwAuthenticate::parse("Bearer realm=\"example\" garbage").is_none());
}
#[test]
fn returns_none_on_missing_equals() {
assert!(WwwAuthenticate::parse("Bearer realm \"example\"").is_none());
}
#[test]
fn returns_none_on_unterminated_quoted_string() {
assert!(WwwAuthenticate::parse("Bearer realm=\"example").is_none());
}
}