summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGalen Guyer <galen@galenguyer.com>2022-05-30 11:35:00 -0400
committerGalen Guyer <galen@galenguyer.com>2022-05-30 11:35:00 -0400
commitc1beee3bd01615753ded3d2bbc4c4caf5a5f7ff0 (patch)
tree244e996128464e8f9061a22d8b25d36a4b75509d
parent678ce8e205dedb4c4cd14f5ff7df739aea1da0c8 (diff)
add clientip extractor for potential future shenanigans
-rw-r--r--src/extractors.rs45
-rw-r--r--src/main.rs11
-rw-r--r--src/routes/v1/records.rs17
-rw-r--r--src/routes/v1/users.rs3
4 files changed, 59 insertions, 17 deletions
diff --git a/src/extractors.rs b/src/extractors.rs
index dd294b6..9f77794 100644
--- a/src/extractors.rs
+++ b/src/extractors.rs
@@ -1,8 +1,9 @@
use axum::async_trait;
+use axum::extract::ConnectInfo;
use axum::extract::FromRequest;
use axum::extract::{rejection::JsonRejection, RequestParts};
use axum::http::header::{self, HeaderValue};
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::BoxError;
use hmac::{Hmac, Mac};
@@ -17,6 +18,8 @@ use sqlx::types::Uuid;
use std::collections::BTreeMap;
use std::env;
use std::error::Error;
+use std::net::IpAddr;
+use std::net::SocketAddr;
lazy_static! {
static ref JWT_SECRET: String = env::var("JWT_SECRET").unwrap();
@@ -175,3 +178,43 @@ where
None
}
}
+
+pub struct ClientIp(pub Option<IpAddr>);
+
+#[async_trait]
+impl<B> FromRequest<B> for ClientIp
+where
+ B: Send,
+{
+ type Rejection = (StatusCode, &'static str);
+
+ async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ let headers = req.headers();
+ Ok(ClientIp(
+ maybe_x_forwarded_for(headers)
+ .or_else(|| maybe_x_real_ip(headers))
+ .or_else(|| maybe_connect_info(req)),
+ ))
+ }
+}
+
+fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
+ headers
+ .get("X-Forwarded-For")
+ .and_then(|value| value.to_str().ok())
+ .and_then(|value| value.split(',').next())
+ .and_then(|value| value.trim().parse().ok())
+}
+
+fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
+ headers
+ .get("X-Real-Ip")
+ .and_then(|value| value.to_str().ok())
+ .and_then(|value| value.parse().ok())
+}
+
+fn maybe_connect_info<B: Send>(req: &RequestParts<B>) -> Option<IpAddr> {
+ req.extensions()
+ .get::<ConnectInfo<SocketAddr>>()
+ .map(|ConnectInfo(addr)| addr.ip())
+}
diff --git a/src/main.rs b/src/main.rs
index a1fb106..2ac7afd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,5 @@
use axum::extract::Extension;
use axum::{
- response::IntoResponse,
routing::{get, post, put},
Router, Server,
};
@@ -23,7 +22,10 @@ async fn main() {
dotenv().ok();
if std::env::args().nth(1) == Some("--version".to_string()) {
- println!("{}", option_env!("CARGO_PKG_VERSION").unwrap_or_else(|| "unknown"));
+ println!(
+ "{}",
+ option_env!("CARGO_PKG_VERSION").unwrap_or_else(|| "unknown")
+ );
return;
}
@@ -51,7 +53,6 @@ async fn main() {
Router::new().nest(
"/v1",
Router::new()
- .route("/", get(root))
.nest(
"/users",
Router::new()
@@ -89,9 +90,7 @@ async fn main() {
info!("Binding to {addr}");
Server::bind(&addr)
- .serve(app.into_make_service())
+ .serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await
.unwrap();
}
-
-async fn root() -> impl IntoResponse {}
diff --git a/src/routes/v1/records.rs b/src/routes/v1/records.rs
index d79e644..fd72663 100644
--- a/src/routes/v1/records.rs
+++ b/src/routes/v1/records.rs
@@ -7,8 +7,8 @@ use axum::response::IntoResponse;
use axum::Extension;
use serde_json::json;
use sqlx::{Pool, Postgres};
-use uuid::Uuid;
use std::sync::Arc;
+use uuid::Uuid;
pub async fn get_records(
Path(id): Path<String>,
@@ -166,12 +166,8 @@ pub async fn delete_record(
// TODO: Make sure record exists
// TODO: Check to make sure record is within zone
- let result = db::records::delete_record(
- &pool,
- &zone_id,
- &Uuid::parse_str(&record_id).unwrap(),
- )
- .await;
+ let result =
+ db::records::delete_record(&pool, &zone_id, &Uuid::parse_str(&record_id).unwrap()).await;
if result.is_err() {
return (
StatusCode::INTERNAL_SERVER_ERROR,
@@ -179,5 +175,10 @@ pub async fn delete_record(
);
}
- (StatusCode::OK, Json(json!({"message": format!("Record {} deleted", record_id)})))
+ (
+ StatusCode::OK,
+ Json(json!({
+ "message": format!("Record {} deleted", record_id)
+ })),
+ )
}
diff --git a/src/routes/v1/users.rs b/src/routes/v1/users.rs
index 00a8b84..10afc0c 100644
--- a/src/routes/v1/users.rs
+++ b/src/routes/v1/users.rs
@@ -155,6 +155,5 @@ fn issue_jwt(user: User) -> String {
claims.insert("email", &user.email);
claims.insert("admin", &admin);
- let token = claims.sign_with_key(&key).unwrap();
- token
+ claims.sign_with_key(&key).unwrap()
}