mirror of
https://github.com/AThilenius/axum-connect.git
synced 2024-11-24 06:19:46 +00:00
Initial commit
This commit is contained in:
commit
ae7afcfbea
11 changed files with 473 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/target
|
||||
/Cargo.lock
|
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"cSpell.words": ["codegen", "proto", "protobuf", "serde"]
|
||||
}
|
3
Cargo.toml
Normal file
3
Cargo.toml
Normal file
|
@ -0,0 +1,3 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = ["axum-connect", "axum-connect-build", "axum-connect-examples"]
|
11
axum-connect-build/Cargo.toml
Normal file
11
axum-connect-build/Cargo.toml
Normal file
|
@ -0,0 +1,11 @@
|
|||
[package]
|
||||
name = "axum-connect-build"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
convert_case = "0.6.0"
|
||||
protobuf = { git = "https://github.com/AThilenius/rust-protobuf.git" }
|
||||
protobuf-codegen = { git = "https://github.com/AThilenius/rust-protobuf.git" }
|
||||
protobuf-parse = { git = "https://github.com/AThilenius/rust-protobuf.git" }
|
153
axum-connect-build/src/lib.rs
Normal file
153
axum-connect-build/src/lib.rs
Normal file
|
@ -0,0 +1,153 @@
|
|||
use std::path::Path;
|
||||
|
||||
use convert_case::{Case, Casing};
|
||||
use protobuf::reflect::FileDescriptor;
|
||||
use protobuf_codegen::{
|
||||
gen::scope::{RootScope, WithScope},
|
||||
Codegen,
|
||||
};
|
||||
use protobuf_parse::ProtobufAbsPath;
|
||||
|
||||
// TODO There is certainly a much easier way to do this, but I can't make sense of rust-protobuf.
|
||||
pub fn axum_connect_codegen(
|
||||
include: impl AsRef<Path>,
|
||||
inputs: impl IntoIterator<Item = impl AsRef<Path>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let results = Codegen::new()
|
||||
.pure()
|
||||
.cargo_out_dir("connect_proto_gen")
|
||||
.inputs(inputs)
|
||||
.include(include)
|
||||
.run()?;
|
||||
|
||||
let file_descriptors =
|
||||
FileDescriptor::new_dynamic_fds(results.parsed.file_descriptors.clone(), &[])?;
|
||||
|
||||
let root_scope = RootScope {
|
||||
file_descriptors: &file_descriptors.as_slice(),
|
||||
};
|
||||
|
||||
for path in results.parsed.relative_paths {
|
||||
// Find the relative file descriptor
|
||||
let file_descriptor = results
|
||||
.parsed
|
||||
.file_descriptors
|
||||
.iter()
|
||||
.find(|&fd| fd.name.clone().unwrap_or_default().ends_with(path.to_str()))
|
||||
.expect(&format!(
|
||||
"find a file descriptor matching the relative path {}",
|
||||
path.to_str()
|
||||
));
|
||||
|
||||
// TODO: This seems fragile.
|
||||
let path = path.to_path().with_extension("rs");
|
||||
let cargo_out_dir = std::env::var("OUT_DIR")?;
|
||||
let out_dir = Path::new(&cargo_out_dir).join("connect_proto_gen");
|
||||
let proto_rs_file_name = path.file_name().unwrap().to_str().unwrap();
|
||||
let proto_rs_full_path = out_dir.join(&proto_rs_file_name);
|
||||
|
||||
// Replace all instances of "::protobuf::" with "::axum_connect::protobuf::" in the original
|
||||
// generated file.
|
||||
let rust = std::fs::read_to_string(&proto_rs_full_path)?;
|
||||
let rust = rust.replace("::protobuf::", "::axum_connect::protobuf::");
|
||||
// std::fs::write(&proto_rs_full_path, rust)?;
|
||||
|
||||
// Build up the service implementation file source.
|
||||
let mut c = String::new();
|
||||
|
||||
c.push_str(FILE_PREAMBLE_TEMPLATE);
|
||||
|
||||
for service in &file_descriptor.service {
|
||||
// Build up methods first
|
||||
let mut m = String::new();
|
||||
|
||||
for method in &service.method {
|
||||
let input_type = root_scope
|
||||
.find_message(&ProtobufAbsPath {
|
||||
path: method.input_type().to_string(),
|
||||
})
|
||||
.rust_name_with_file()
|
||||
.to_path()
|
||||
.to_string();
|
||||
|
||||
let output_type = root_scope
|
||||
.find_message(&ProtobufAbsPath {
|
||||
path: method.output_type().to_string(),
|
||||
})
|
||||
.rust_name_with_file()
|
||||
.to_path()
|
||||
.to_string();
|
||||
|
||||
m.push_str(
|
||||
&METHOD_TEMPLATE
|
||||
.replace("@@METHOD_NAME@@", &method.name().to_case(Case::Snake))
|
||||
.replace("@@INPUT_TYPE@@", &input_type)
|
||||
.replace("@@OUTPUT_TYPE@@", &output_type)
|
||||
.replace(
|
||||
"@@ROUTE@@",
|
||||
&format!(
|
||||
"/{}.{}/{}",
|
||||
file_descriptor.package(),
|
||||
service.name(),
|
||||
method.name()
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
c.push_str(
|
||||
&SERVICE_TEMPLATE
|
||||
.replace("@@SERVICE_NAME@@", service.name())
|
||||
.replace("@@SERVICE_METHODS@@", &m),
|
||||
);
|
||||
}
|
||||
|
||||
let mut final_file = String::new();
|
||||
final_file.push_str(&rust);
|
||||
final_file.push_str(&c);
|
||||
|
||||
std::fs::write(&proto_rs_full_path, &final_file)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const FILE_PREAMBLE_TEMPLATE: &str = "// Generated by axum-connect-build
|
||||
use axum::{
|
||||
body::HttpBody, extract::State, http::Request, response::IntoResponse, routing::post, BoxError,
|
||||
Router,
|
||||
};
|
||||
|
||||
use axum_connect::{HandlerFuture, RpcRouter};
|
||||
";
|
||||
|
||||
const SERVICE_TEMPLATE: &str = "
|
||||
pub struct @@SERVICE_NAME@@;
|
||||
|
||||
impl @@SERVICE_NAME@@ {
|
||||
@@SERVICE_METHODS@@
|
||||
}";
|
||||
|
||||
const METHOD_TEMPLATE: &str = "
|
||||
pub fn @@METHOD_NAME@@<T, H, S, B>(handler: H) -> impl FnOnce(Router<S, B>) -> RpcRouter<S, B>
|
||||
where
|
||||
H: HandlerFuture<super::@@INPUT_TYPE@@, super::@@OUTPUT_TYPE@@, T, S, B>,
|
||||
T: 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
{
|
||||
move |router: Router<S, B>| {
|
||||
router.route(
|
||||
\"@@ROUTE@@\",
|
||||
post(|State(state): State<S>, request: Request<B>| async move {
|
||||
let res = handler.call(request, state).await;
|
||||
::axum_connect::protobuf_json_mapping::print_to_string(&res)
|
||||
.unwrap()
|
||||
.into_response()
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
";
|
12
axum-connect-examples/Cargo.toml
Normal file
12
axum-connect-examples/Cargo.toml
Normal file
|
@ -0,0 +1,12 @@
|
|||
[package]
|
||||
name = "hello_world"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6.9"
|
||||
axum-connect = { path = "../axum-connect" }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
|
||||
[build-dependencies]
|
||||
axum-connect-build = { path = "../axum-connect-build" }
|
5
axum-connect-examples/build.rs
Normal file
5
axum-connect-examples/build.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
use axum_connect_build::axum_connect_codegen;
|
||||
|
||||
fn main() {
|
||||
axum_connect_codegen("proto", &["proto/hello.proto"]).unwrap();
|
||||
}
|
15
axum-connect-examples/proto/hello.proto
Normal file
15
axum-connect-examples/proto/hello.proto
Normal file
|
@ -0,0 +1,15 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package axum_connect.examples.hello_world;
|
||||
|
||||
message HelloRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message HelloResponse {
|
||||
string message = 1;
|
||||
}
|
||||
|
||||
service HelloWorldService {
|
||||
rpc SayHello(HelloRequest) returns (HelloResponse) {}
|
||||
}
|
33
axum-connect-examples/src/main.rs
Normal file
33
axum-connect-examples/src/main.rs
Normal file
|
@ -0,0 +1,33 @@
|
|||
use std::net::SocketAddr;
|
||||
|
||||
use axum::{extract::Host, Router};
|
||||
use axum_connect::*;
|
||||
use proto::hello::{HelloRequest, HelloResponse, HelloWorldService};
|
||||
|
||||
mod proto {
|
||||
include!(concat!(env!("OUT_DIR"), "/connect_proto_gen/mod.rs"));
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Build our application with a route
|
||||
let app = Router::new().rpc(HelloWorldService::say_hello(say_hello_handler));
|
||||
|
||||
// Run the Axum server.
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
println!("listening on http://{}", addr);
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn say_hello_handler(Host(host): Host, request: HelloRequest) -> HelloResponse {
|
||||
HelloResponse {
|
||||
message: format!(
|
||||
"Hello {}! You're addressing the hostname: {}.",
|
||||
request.name, host
|
||||
),
|
||||
special_fields: Default::default(),
|
||||
}
|
||||
}
|
12
axum-connect/Cargo.toml
Normal file
12
axum-connect/Cargo.toml
Normal file
|
@ -0,0 +1,12 @@
|
|||
[package]
|
||||
name = "axum-connect"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6.9"
|
||||
futures = "0.3.26"
|
||||
protobuf = "3.2.0"
|
||||
protobuf-json-mapping = "3.2.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
224
axum-connect/src/lib.rs
Normal file
224
axum-connect/src/lib.rs
Normal file
|
@ -0,0 +1,224 @@
|
|||
use std::pin::Pin;
|
||||
|
||||
use axum::{
|
||||
body::{Body, HttpBody},
|
||||
extract::{FromRequest, FromRequestParts},
|
||||
http::{Request, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
BoxError, Router,
|
||||
};
|
||||
use futures::Future;
|
||||
use protobuf::MessageFull;
|
||||
use serde::Serialize;
|
||||
|
||||
pub use protobuf;
|
||||
pub use protobuf_json_mapping;
|
||||
|
||||
pub trait RpcRouterExt<S, B>: Sized {
|
||||
fn rpc<F>(self, register: F) -> Self
|
||||
where
|
||||
F: FnOnce(Self) -> RpcRouter<S, B>;
|
||||
}
|
||||
|
||||
impl<S, B> RpcRouterExt<S, B> for Router<S, B> {
|
||||
fn rpc<F>(self, register: F) -> Self
|
||||
where
|
||||
F: FnOnce(Self) -> RpcRouter<S, B>,
|
||||
{
|
||||
register(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub type RpcRouter<S, B> = Router<S, B>;
|
||||
|
||||
pub trait RegisterRpcService<S, B>: Sized {
|
||||
fn register(self, router: Router<S, B>) -> Self;
|
||||
}
|
||||
|
||||
pub trait IntoRpcResponse<T>
|
||||
where
|
||||
T: MessageFull,
|
||||
{
|
||||
fn into_response(self) -> Response;
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct RpcError {
|
||||
pub code: RpcErrorCode,
|
||||
pub message: String,
|
||||
pub details: Vec<RpcErrorDetail>,
|
||||
}
|
||||
|
||||
impl RpcError {
|
||||
pub fn new(code: RpcErrorCode, message: String) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message,
|
||||
details: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct RpcErrorDetail {
|
||||
#[serde(rename = "type")]
|
||||
pub proto_type: String,
|
||||
#[serde(rename = "value")]
|
||||
pub proto_b62_value: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RpcErrorCode {
|
||||
Canceled,
|
||||
Unknown,
|
||||
InvalidArgument,
|
||||
DeadlineExceeded,
|
||||
NotFound,
|
||||
AlreadyExists,
|
||||
PermissionDenied,
|
||||
ResourceExhausted,
|
||||
FailedPrecondition,
|
||||
Aborted,
|
||||
OutOfRange,
|
||||
Unimplemented,
|
||||
Internal,
|
||||
Unavailable,
|
||||
DataLoss,
|
||||
Unauthenticated,
|
||||
}
|
||||
|
||||
impl From<RpcErrorCode> for StatusCode {
|
||||
fn from(val: RpcErrorCode) -> Self {
|
||||
match val {
|
||||
// Spec: https://connect.build/docs/protocol/#error-codes
|
||||
RpcErrorCode::Canceled => StatusCode::REQUEST_TIMEOUT,
|
||||
RpcErrorCode::Unknown => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RpcErrorCode::InvalidArgument => StatusCode::BAD_REQUEST,
|
||||
RpcErrorCode::DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
|
||||
RpcErrorCode::NotFound => StatusCode::NOT_FOUND,
|
||||
RpcErrorCode::AlreadyExists => StatusCode::CONFLICT,
|
||||
RpcErrorCode::PermissionDenied => StatusCode::FORBIDDEN,
|
||||
RpcErrorCode::ResourceExhausted => StatusCode::TOO_MANY_REQUESTS,
|
||||
RpcErrorCode::FailedPrecondition => StatusCode::PRECONDITION_FAILED,
|
||||
RpcErrorCode::Aborted => StatusCode::CONFLICT,
|
||||
RpcErrorCode::OutOfRange => StatusCode::BAD_REQUEST,
|
||||
RpcErrorCode::Unimplemented => StatusCode::NOT_FOUND,
|
||||
RpcErrorCode::Internal => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RpcErrorCode::Unavailable => StatusCode::SERVICE_UNAVAILABLE,
|
||||
RpcErrorCode::DataLoss => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RpcErrorCode::Unauthenticated => StatusCode::UNAUTHORIZED,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for RpcError {
|
||||
fn into_response(self) -> Response {
|
||||
let status_code = StatusCode::from(self.code.clone());
|
||||
let json = serde_json::to_string(&self).expect("serialize error type");
|
||||
(status_code, json).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, E> IntoRpcResponse<T> for Result<T, E>
|
||||
where
|
||||
T: MessageFull,
|
||||
E: Into<RpcError>,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
Ok(res) => rpc_to_response(res),
|
||||
Err(err) => err.into().into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait HandlerFuture<TReq, TRes, T, S, B = Body>: Clone + Send + Sized + 'static {
|
||||
type Future: Future<Output = TRes> + Send + 'static;
|
||||
|
||||
fn call(self, req: Request<B>, state: S) -> Self::Future;
|
||||
}
|
||||
|
||||
fn rpc_to_response<T>(res: T) -> Response
|
||||
where
|
||||
T: MessageFull,
|
||||
{
|
||||
protobuf_json_mapping::print_to_string(&res)
|
||||
.map_err(|_e| {
|
||||
RpcError::new(
|
||||
RpcErrorCode::Internal,
|
||||
"Failed to serialize response".to_string(),
|
||||
)
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
macro_rules! impl_handler {
|
||||
(
|
||||
[$($ty:ident),*]
|
||||
) => {
|
||||
#[allow(unused_parens, non_snake_case, unused_mut)]
|
||||
impl<TReq, TRes, F, Fut, S, B, $($ty,)*> HandlerFuture<TReq, TRes, ($($ty,)* TReq), S, B> for F
|
||||
where
|
||||
TReq: MessageFull + Send + 'static,
|
||||
TRes: MessageFull + Send + 'static,
|
||||
F: FnOnce($($ty,)* TReq) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = TRes> + Send,
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
S: Send + Sync + 'static,
|
||||
$( $ty: FromRequestParts<S> + Send, )*
|
||||
{
|
||||
type Future = Pin<Box<dyn Future<Output = TRes> + Send>>;
|
||||
|
||||
fn call(self, req: Request<B>, state: S) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let state = &state;
|
||||
|
||||
// This would be done by macro expansion. It also wouldn't be unwrapped, but
|
||||
// there is no error union so I can't return a rejection.
|
||||
$(
|
||||
let $ty = match $ty::from_request_parts(&mut parts, state).await {
|
||||
Ok(value) => value,
|
||||
Err(_e) => unreachable!(),
|
||||
};
|
||||
)*
|
||||
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
let body = match String::from_request(req, state).await {
|
||||
Ok(value) => value,
|
||||
Err(_e) => unreachable!(),
|
||||
};
|
||||
|
||||
let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) {
|
||||
Ok(value) => value,
|
||||
Err(_e) => unreachable!(),
|
||||
};
|
||||
|
||||
let res = self($($ty,)* proto_req).await;
|
||||
res
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_handler!([]);
|
||||
impl_handler!([T1]);
|
||||
impl_handler!([T1, T2]);
|
||||
impl_handler!([T1, T2, T3]);
|
||||
impl_handler!([T1, T2, T3, T4]);
|
||||
impl_handler!([T1, T2, T3, T4, T5]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14]);
|
||||
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15]);
|
Loading…
Reference in a new issue