// scufflecloud_core/middleware/auth.rs

use std::str::FromStr;
use std::sync::Arc;

use axum::extract::Request;
use axum::http::{HeaderMap, HeaderName, StatusCode};
use axum::middleware::Next;
use axum::response::Response;
use base64::Engine;
use core_db_types::models::{UserSession, UserSessionTokenId};
use core_db_types::schema::user_sessions;
use diesel::{BoolExpressionMethods, ExpressionMethods, QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use ext_traits::RequestExt;
use fred::prelude::KeysInterface;
use geo_ip::GeoIpRequestExt;
use geo_ip::middleware::IpAddressInfo;
use hmac::Mac;

19const TOKEN_ID_HEADER: HeaderName = HeaderName::from_static("scuf-token-id");
20const TIMESTAMP_HEADER: HeaderName = HeaderName::from_static("scuf-timestamp");
21const NONCE_HEADER: HeaderName = HeaderName::from_static("scuf-nonce");
22
23const AUTHENTICATION_METHOD_HEADER: HeaderName = HeaderName::from_static("scuf-auth-method");
24const AUTHENTICATION_HMAC_HEADER: HeaderName = HeaderName::from_static("scuf-auth-hmac");
25
26const USER_AGENT_HEADER: HeaderName = HeaderName::from_static("user-agent");
27
28pub(crate) const fn auth_headers() -> [HeaderName; 5] {
29    [
30        TOKEN_ID_HEADER,
31        TIMESTAMP_HEADER,
32        NONCE_HEADER,
33        AUTHENTICATION_METHOD_HEADER,
34        AUTHENTICATION_HMAC_HEADER,
35    ]
36}
37
38#[derive(Clone, Debug)]
39pub(crate) struct ExpiredSession(pub UserSession);
40
41pub(crate) async fn auth<G: core_traits::Global>(mut req: Request, next: Next) -> Result<Response, StatusCode> {
42    let global = req
43        .extensions()
44        .global::<G>()
45        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
46    let ip_info = req
47        .extensions()
48        .ip_address_info()
49        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
50
51    let (session, expired_session) = get_and_update_active_session(&global, &ip_info, req.headers()).await?;
52    if let Some(session) = session {
53        req.extensions_mut().insert(session);
54    }
55    if let Some(expired_session) = expired_session {
56        req.extensions_mut().insert(expired_session);
57    }
58
59    Ok(next.run(req).await)
60}
61
62fn get_header_value<'a, T>(headers: &'a HeaderMap, header_name: &HeaderName) -> Result<Option<T>, StatusCode>
63where
64    T: FromStr + 'a,
65    T::Err: std::fmt::Display,
66{
67    match headers.get(header_name) {
68        Some(h) => {
69            let s = h.to_str().map_err(|e| {
70                tracing::debug!(header = %header_name, error = %e, "invalid header value");
71                StatusCode::BAD_REQUEST
72            })?;
73            Ok(Some(s.parse().map_err(|e| {
74                tracing::debug!(header = %header_name, error = %e, "failed to parse header value");
75                StatusCode::BAD_REQUEST
76            })?))
77        }
78        None => Ok(None),
79    }
80}
81
82#[derive(Debug, thiserror::Error)]
83enum AuthenticationMethodParseError {
84    #[error("unknown authentication algorithm")]
85    UnknownAlgorithm,
86    #[error("invalid header format")]
87    InvalidHeaderFormat,
88}
89
90#[derive(Debug)]
91enum AuthenticationAlgorithm {
92    HmacSha256,
93}
94
95impl FromStr for AuthenticationAlgorithm {
96    type Err = AuthenticationMethodParseError;
97
98    fn from_str(s: &str) -> Result<Self, Self::Err> {
99        match s {
100            "HMAC-SHA256" => Ok(AuthenticationAlgorithm::HmacSha256),
101            _ => Err(AuthenticationMethodParseError::UnknownAlgorithm),
102        }
103    }
104}
105
106#[derive(Debug)]
107struct AuthenticationMethod {
108    pub algorithm: AuthenticationAlgorithm,
109    pub headers: Vec<HeaderName>,
110}
111
112impl FromStr for AuthenticationMethod {
113    type Err = AuthenticationMethodParseError;
114
115    fn from_str(s: &str) -> Result<Self, Self::Err> {
116        let parts: Vec<&str> = s.splitn(2, ';').collect();
117        if parts.len() != 2 {
118            return Err(AuthenticationMethodParseError::InvalidHeaderFormat);
119        }
120
121        let algorithm: AuthenticationAlgorithm = parts[0].parse()?;
122        let headers: Vec<HeaderName> = parts[1]
123            .split(',')
124            .map(|h| HeaderName::from_str(h.trim()).map_err(|_| AuthenticationMethodParseError::InvalidHeaderFormat))
125            .collect::<Result<_, _>>()?;
126
127        Ok(AuthenticationMethod { algorithm, headers })
128    }
129}
130
131#[derive(thiserror::Error, Debug)]
132enum NonceParseError {
133    #[error("failed to decode: {0}")]
134    Base64(#[from] base64::DecodeError),
135    #[error("invalid nonce length {0}, must be 32 bytes")]
136    InvalidLength(usize),
137}
138
139#[derive(Debug)]
140struct Nonce(Vec<u8>);
141
142impl FromStr for Nonce {
143    type Err = NonceParseError;
144
145    fn from_str(s: &str) -> Result<Self, Self::Err> {
146        let bytes = base64::prelude::BASE64_STANDARD.decode(s)?;
147        if bytes.len() != 32 {
148            return Err(NonceParseError::InvalidLength(bytes.len()));
149        }
150        Ok(Nonce(bytes))
151    }
152}
153
154#[derive(Debug)]
155struct AuthenticationHmac(Vec<u8>);
156
157impl FromStr for AuthenticationHmac {
158    type Err = base64::DecodeError;
159
160    fn from_str(s: &str) -> Result<Self, Self::Err> {
161        let bytes = base64::prelude::BASE64_STANDARD.decode(s)?;
162        Ok(AuthenticationHmac(bytes))
163    }
164}
165
166async fn get_and_update_active_session<G: core_traits::Global>(
167    global: &Arc<G>,
168    ip_info: &IpAddressInfo,
169    headers: &HeaderMap,
170) -> Result<(Option<UserSession>, Option<ExpiredSession>), StatusCode> {
171    let Some(session_token_id) = get_header_value::<UserSessionTokenId>(headers, &TOKEN_ID_HEADER)? else {
172        return Ok((None, None));
173    };
174    let Some(timestamp) =
175        get_header_value::<u64>(headers, &TIMESTAMP_HEADER)?.and_then(|t| chrono::DateTime::from_timestamp_millis(t as i64))
176    else {
177        return Ok((None, None));
178    };
179    let Some(nonce) = get_header_value::<Nonce>(headers, &NONCE_HEADER)? else {
180        return Ok((None, None));
181    };
182
183    let Some(auth_method) = get_header_value::<AuthenticationMethod>(headers, &AUTHENTICATION_METHOD_HEADER)? else {
184        return Ok((None, None));
185    };
186    let Some(auth_hmac) = get_header_value::<AuthenticationHmac>(headers, &AUTHENTICATION_HMAC_HEADER)? else {
187        return Ok((None, None));
188    };
189
190    if (chrono::Utc::now() - timestamp).abs()
191        > chrono::TimeDelta::from_std(global.timeout_config().max_request_diff).expect("invalid config")
192    {
193        tracing::debug!(timestamp = %timestamp, "invalid request timestamp");
194        return Err(StatusCode::UNAUTHORIZED);
195    }
196
197    if !auth_method.headers.contains(&TOKEN_ID_HEADER)
198        || !auth_method.headers.contains(&TIMESTAMP_HEADER)
199        || !auth_method.headers.contains(&NONCE_HEADER)
200    {
201        tracing::debug!("missing required headers in authentication method");
202        return Err(StatusCode::BAD_REQUEST);
203    }
204
205    let last_user_agent = get_header_value::<String>(headers, &USER_AGENT_HEADER)?;
206
207    let mut db = global.db().await.map_err(|e| {
208        tracing::error!(error = %e, "failed to connect to database");
209        StatusCode::INTERNAL_SERVER_ERROR
210    })?;
211
212    let Some(session) = diesel::update(user_sessions::dsl::user_sessions)
213        .set((
214            user_sessions::dsl::last_ip.eq(ip_info.to_network()),
215            user_sessions::dsl::last_user_agent.eq(last_user_agent),
216            user_sessions::dsl::last_used_at.eq(chrono::Utc::now()),
217        ))
218        .filter(
219            user_sessions::dsl::token_id
220                .eq(session_token_id)
221                .and(user_sessions::dsl::token.is_not_null())
222                .and(user_sessions::dsl::expires_at.gt(chrono::Utc::now())),
223        )
224        .returning(UserSession::as_select())
225        .get_results::<UserSession>(&mut db)
226        .await
227        .map_err(|e| {
228            tracing::error!(error = %e, "failed to update user session");
229            StatusCode::INTERNAL_SERVER_ERROR
230        })?
231        .into_iter()
232        .next()
233    else {
234        tracing::debug!(token_id = %session_token_id, "no active session found");
235        return Err(StatusCode::UNAUTHORIZED);
236    };
237
238    let token = session.token.as_ref().expect("known to be not null due to filter");
239
240    // Verify HMAC
241    match auth_method.algorithm {
242        AuthenticationAlgorithm::HmacSha256 => {
243            let mut mac = hmac::Hmac::<sha2::Sha256>::new_from_slice(token).map_err(|e| {
244                tracing::error!(error = %e, "failed to create HMAC instance");
245                StatusCode::INTERNAL_SERVER_ERROR
246            })?;
247
248            for header_name in &auth_method.headers {
249                if let Some(value) = headers.get(header_name) {
250                    mac.update(value.as_bytes());
251                } else {
252                    tracing::debug!(header = %header_name, "missing header");
253                    return Err(StatusCode::BAD_REQUEST);
254                }
255            }
256
257            mac.verify_slice(&auth_hmac.0).map_err(|e| {
258                tracing::debug!(error = %e, "HMAC verification failed");
259                StatusCode::UNAUTHORIZED
260            })?;
261        }
262    }
263
264    let mut key = "nonces:".as_bytes().to_vec();
265    key.extend_from_slice(&nonce.0);
266    let value: Option<bool> = global
267        .redis()
268        .set(
269            key.as_slice(),
270            true,
271            Some(fred::types::Expiration::PX(
272                global.timeout_config().max_request_diff.as_millis() as i64,
273            )),
274            Some(fred::types::SetOptions::NX),
275            true,
276        )
277        .await
278        .map_err(|e| {
279            tracing::error!(error = %e, "failed to set nonce in redis");
280            StatusCode::INTERNAL_SERVER_ERROR
281        })?;
282
283    if value.is_some() {
284        tracing::debug!("replayed nonce detected");
285        return Err(StatusCode::UNAUTHORIZED);
286    }
287
288    if session.token_expires_at.is_some_and(|t| t <= chrono::Utc::now()) {
289        return Ok((None, Some(ExpiredSession(session))));
290    }
291
292    Ok((Some(session), None))
293}