diff --git a/.vscode/settings.json b/.vscode/settings.json index a126254..60c622c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,9 +2,13 @@ "cSpell.words": [ "bufbuild", "codegen", + "impls", + "pbjson", + "prost", "proto", "protobuf", "protoc", + "protos", "serde", "Thilenius", "typecheck" diff --git a/axum-connect-build/Cargo.toml b/axum-connect-build/Cargo.toml index 7f3031e..f4ce2a6 100644 --- a/axum-connect-build/Cargo.toml +++ b/axum-connect-build/Cargo.toml @@ -12,6 +12,10 @@ repository = "https://github.com/AThilenius/axum-connect" [dependencies] anyhow = "1.0" convert_case = "0.6.0" -protobuf = "3.2.0" -protobuf-codegen = "3.2.0" -protobuf-parse = "3.2.0" +pbjson-build = "0.5.1" +proc-macro2 = "1.0.56" +prost = "0.11.9" +prost-build = "0.11.9" +prost-reflect = "0.11.4" +quote = "1.0.26" +syn = "2.0.15" diff --git a/axum-connect-build/src/gen.rs b/axum-connect-build/src/gen.rs new file mode 100644 index 0000000..a878e41 --- /dev/null +++ b/axum-connect-build/src/gen.rs @@ -0,0 +1,77 @@ +use proc_macro2::TokenStream; +use prost_build::{Method, Service, ServiceGenerator}; +use quote::{format_ident, quote}; +use syn::parse_str; + +#[derive(Default)] +pub struct AxumConnectServiceGenerator {} + +impl AxumConnectServiceGenerator { + pub fn new() -> Self { + Default::default() + } + + fn generate_service(&mut self, service: Service, buf: &mut String) { + // Service struct + let service_name = format_ident!("{}", service.name); + let methods = service.methods.into_iter().map(|m| { + self.generate_service_method(m, &format!("{}.{}", service.package, service.proto_name)) + }); + + buf.push_str( + quote! { + pub struct #service_name; + + impl #service_name { + #(#methods)* + } + } + .to_string() + .as_str(), + ); + } + + fn generate_service_method(&mut self, method: Method, path_root: &str) -> TokenStream { + let method_name = format_ident!("{}", method.name); + let input_type: syn::Type = parse_str(&method.input_type).unwrap(); + let output_type: syn::Type = parse_str(&method.output_type).unwrap(); + let path = format!("/{}/{}", path_root, method.proto_name); + + quote! { + pub fn #method_name( + handler: H + ) -> impl FnOnce(axum::Router) -> axum_connect::router::RpcRouter + where + H: axum_connect::handler::HandlerFuture<#input_type, #output_type, R, T, S, B>, + T: 'static, + S: Clone + Send + Sync + 'static, + B: axum::body::HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, + { + use axum::response::IntoResponse; + + move |router: axum::Router| { + router.route( + #path, + axum::routing::post(|axum::extract::State(state): axum::extract::State, request: axum::http::Request| async move { + let res = handler.call(request, state).await; + res.into_response() + }), + ) + } + } + } + } +} + +impl ServiceGenerator for AxumConnectServiceGenerator { + fn generate(&mut self, service: Service, buf: &mut String) { + self.generate_service(service, buf); + } + + fn finalize(&mut self, buf: &mut String) { + // Add serde import (because that's less effort than hacking pbjson). + buf.push_str("\nuse axum_connect::serde;\n"); + } +} diff --git a/axum-connect-build/src/lib.rs b/axum-connect-build/src/lib.rs index 35657fb..3212f23 100644 --- a/axum-connect-build/src/lib.rs +++ b/axum-connect-build/src/lib.rs @@ -1,205 +1,44 @@ -use std::{collections::HashMap, path::Path}; +use std::{ + env, + io::{BufWriter, Write}, + path::{Path, PathBuf}, +}; -use convert_case::{Case, Casing}; -use protobuf::reflect::{FileDescriptor, MessageDescriptor}; -use protobuf_parse::Parser; +use gen::AxumConnectServiceGenerator; + +mod gen; -// TODO There is certainly a much easier way to do this, but I can't make sense of rust-protobuf. pub fn axum_connect_codegen( - include: impl AsRef, + include: &[impl AsRef], inputs: &[impl AsRef], ) -> anyhow::Result<()> { - protobuf_codegen::Codegen::new() - .pure() - .cargo_out_dir("connect_proto_gen") - .inputs(inputs) - .include(&include) - .run()?; + let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin"); - let mut parser = Parser::new(); - parser.pure(); - parser.inputs(inputs); - parser.include(include); + let mut conf = prost_build::Config::new(); + conf.compile_well_known_types(); + conf.file_descriptor_set_path(&descriptor_path); + conf.extern_path(".google.protobuf", "::pbjson_types"); + conf.service_generator(Box::new(AxumConnectServiceGenerator::new())); + conf.compile_protos(inputs, include).unwrap(); - let parsed = parser - .parse_and_typecheck() - .expect("parse and typecheck the protobuf files"); + // Use pbjson to generate the Serde impls, and inline them with the Prost files. + let descriptor_set = std::fs::read(descriptor_path)?; + let mut output: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap()); + output.push("FILENAME"); - let file_descriptors = FileDescriptor::new_dynamic_fds(parsed.file_descriptors.clone(), &[])?; + let writers = pbjson_build::Builder::new() + .register_descriptors(&descriptor_set)? + .generate(&["."], move |package| { + output.set_file_name(format!("{}.rs", package)); - // Create a list of relative paths and their corresponding file descriptors. - let relative_paths = parsed - .relative_paths - .iter() - .map(|p| { - ( - file_descriptors - .iter() - .find(|&fd| fd.name().ends_with(p.to_str())) - .expect(&format!( - "find a file descriptor matching the relative path {}", - p.to_str() - )) - .clone(), - p.to_path() - .file_stem() - .unwrap() - .to_str() - .unwrap() - .to_string(), - ) - }) - .collect(); + let file = std::fs::OpenOptions::new().append(true).open(&output)?; - // Flat map the full proto names to their respective rust type name. - let message_names = map_names(relative_paths); + Ok(BufWriter::new(file)) + })?; - for path in parsed.relative_paths { - // Find the relative file descriptor - let file_descriptor = parsed - .file_descriptors - .iter() - .find(|&fd| fd.name.clone().unwrap_or_default().ends_with(path.to_str())) - .expect(&format!( - "find a file descriptor matching the relative path {}", - path.to_str() - )); - - // TODO: This seems fragile. - let path = path.to_path().with_extension("rs"); - let cargo_out_dir = std::env::var("OUT_DIR")?; - let out_dir = Path::new(&cargo_out_dir).join("connect_proto_gen"); - let proto_rs_file_name = path.file_name().unwrap().to_str().unwrap(); - let proto_rs_full_path = out_dir.join(&proto_rs_file_name); - - // Replace all instances of "::protobuf::" with "::axum_connect::protobuf::" in the original - // generated file. - let rust = std::fs::read_to_string(&proto_rs_full_path)?; - let rust = rust.replace("::protobuf::", "::axum_connect::protobuf::"); - // std::fs::write(&proto_rs_full_path, rust)?; - - // Build up the service implementation file source. - let mut c = String::new(); - - c.push_str(FILE_PREAMBLE_TEMPLATE); - - for service in &file_descriptor.service { - // Build up methods first - let mut m = String::new(); - - for method in &service.method { - let input_type = message_names.get(method.input_type()).unwrap(); - let output_type = message_names.get(method.output_type()).unwrap(); - - m.push_str( - &METHOD_TEMPLATE - .replace("@@METHOD_NAME@@", &method.name().to_case(Case::Snake)) - .replace("@@INPUT_TYPE@@", &input_type) - .replace("@@OUTPUT_TYPE@@", &output_type) - .replace( - "@@ROUTE@@", - &format!( - "/{}.{}/{}", - file_descriptor.package(), - service.name(), - method.name() - ), - ), - ); - } - - c.push_str( - &SERVICE_TEMPLATE - .replace("@@SERVICE_NAME@@", service.name()) - .replace("@@SERVICE_METHODS@@", &m), - ); - } - - let mut final_file = String::new(); - final_file.push_str(&rust); - final_file.push_str(&c); - - std::fs::write(&proto_rs_full_path, &final_file)?; + for (_, mut writer) in writers { + writer.flush()?; } Ok(()) } - -const FILE_PREAMBLE_TEMPLATE: &str = "// Generated by axum-connect-build -use axum::{ - body::HttpBody, extract::State, http::Request, response::IntoResponse, routing::post, BoxError, - Router, -}; - -use axum_connect::{handler::HandlerFuture, router::RpcRouter}; -"; - -const SERVICE_TEMPLATE: &str = " -pub struct @@SERVICE_NAME@@; - -impl @@SERVICE_NAME@@ { -@@SERVICE_METHODS@@ -}"; - -const METHOD_TEMPLATE: &str = " - pub fn @@METHOD_NAME@@(handler: H) -> impl FnOnce(Router) -> RpcRouter - where - H: HandlerFuture<@@INPUT_TYPE@@, @@OUTPUT_TYPE@@, R, T, S, B>, - T: 'static, - S: Clone + Send + Sync + 'static, - B: HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, - { - move |router: Router| { - router.route( - \"@@ROUTE@@\", - post(|State(state): State, request: Request| async move { - let res = handler.call(request, state).await; - res.into_response() - }), - ) - } - } -"; - -/// Takes a vec to FileDescriptor and returns a flattened map of the full name for each message type -/// to the associated Rust name relative to `super`. -fn map_names(file_descriptors: Vec<(FileDescriptor, String)>) -> HashMap { - let mut map = HashMap::new(); - - for (file_descriptor, path) in file_descriptors { - for message in file_descriptor.messages() { - collect_messages_recursive(&mut map, message, format!("super::{}", path)); - } - } - - map -} - -fn collect_messages_recursive( - map: &mut HashMap, - message: MessageDescriptor, - parent_rust_path: String, -) { - map.insert( - format!(".{}", message.full_name()), - format!( - "{}::{}", - parent_rust_path, - message.name().to_case(Case::UpperCamel) - ), - ); - - for nested in message.nested_messages() { - collect_messages_recursive( - map, - nested, - format!( - "{}::{}", - parent_rust_path, - message.name().to_case(Case::Snake) - ), - ); - } -} diff --git a/axum-connect-examples/Cargo.toml b/axum-connect-examples/Cargo.toml index dd681be..5f933e0 100644 --- a/axum-connect-examples/Cargo.toml +++ b/axum-connect-examples/Cargo.toml @@ -1,11 +1,12 @@ [package] -name = "hello_world" +name = "axum-connect-example" version = "0.1.0" edition = "2021" [dependencies] axum = "0.6.9" axum-connect = { path = "../axum-connect" } +prost = "0.11.9" tokio = { version = "1.0", features = ["full"] } [build-dependencies] diff --git a/axum-connect-examples/build.rs b/axum-connect-examples/build.rs index 653de51..71002bb 100644 --- a/axum-connect-examples/build.rs +++ b/axum-connect-examples/build.rs @@ -1,5 +1,5 @@ use axum_connect_build::axum_connect_codegen; fn main() { - axum_connect_codegen("proto", &["proto/hello.proto"]).unwrap(); + axum_connect_codegen(&["proto"], &["proto/hello.proto"]).unwrap(); } diff --git a/axum-connect-examples/proto/hello.proto b/axum-connect-examples/proto/hello.proto index 36689ed..7004699 100644 --- a/axum-connect-examples/proto/hello.proto +++ b/axum-connect-examples/proto/hello.proto @@ -1,6 +1,6 @@ syntax = "proto3"; -package blink.hello; +package hello; message HelloRequest { string name = 1; } diff --git a/axum-connect-examples/src/main.rs b/axum-connect-examples/src/main.rs index 41aa4df..630b245 100644 --- a/axum-connect-examples/src/main.rs +++ b/axum-connect-examples/src/main.rs @@ -2,10 +2,12 @@ use std::net::SocketAddr; use axum::{extract::Host, Router}; use axum_connect::prelude::*; -use proto::hello::{HelloRequest, HelloResponse, HelloWorldService}; +use proto::hello::*; mod proto { - include!(concat!(env!("OUT_DIR"), "/connect_proto_gen/mod.rs")); + pub mod hello { + include!(concat!(env!("OUT_DIR"), "/hello.rs")); + } } #[tokio::main] @@ -30,6 +32,5 @@ async fn say_hello_success(Host(host): Host, request: HelloRequest) -> HelloResp "Hello {}! You're addressing the hostname: {}.", request.name, host ), - special_fields: Default::default(), } } diff --git a/axum-connect/Cargo.toml b/axum-connect/Cargo.toml index ccdc85c..5a04c06 100644 --- a/axum-connect/Cargo.toml +++ b/axum-connect/Cargo.toml @@ -17,7 +17,6 @@ repository = "https://github.com/AThilenius/axum-connect" async-trait = "0.1.64" axum = "0.6.9" futures = "0.3.26" -protobuf = "3.2.0" -protobuf-json-mapping = "3.2.0" +prost = "0.11.9" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/axum-connect/src/error.rs b/axum-connect/src/error.rs index de30adc..cb709d4 100644 --- a/axum-connect/src/error.rs +++ b/axum-connect/src/error.rs @@ -2,7 +2,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; -use protobuf::MessageFull; +use prost::Message; use serde::Serialize; use crate::{prelude::RpcResponse, response::RpcIntoResponse}; @@ -103,7 +103,7 @@ impl From for StatusCode { impl RpcIntoResponse for RpcErrorCode where - T: MessageFull, + T: Message, { fn rpc_into_response(self) -> RpcResponse { RpcResponse { @@ -115,7 +115,7 @@ where impl RpcIntoResponse for RpcError where - T: MessageFull, + T: Message, { fn rpc_into_response(self) -> RpcResponse { RpcResponse { diff --git a/axum-connect/src/handler.rs b/axum-connect/src/handler.rs index 56bced7..73008f7 100644 --- a/axum-connect/src/handler.rs +++ b/axum-connect/src/handler.rs @@ -2,10 +2,8 @@ use std::pin::Pin; use axum::{body::HttpBody, extract::FromRequest, http::Request, BoxError}; use futures::Future; -use protobuf::MessageFull; - -pub use protobuf; -pub use protobuf_json_mapping; +use prost::Message; +use serde::{de::DeserializeOwned, Serialize}; pub use crate::{error::RpcIntoError, parts::RpcFromRequestParts, response::RpcIntoResponse}; use crate::{ @@ -19,14 +17,13 @@ pub trait HandlerFuture: Clone + Send + Sized + 'stati fn call(self, req: Request, state: S) -> Self::Future; } -// This is a single expanded version of the macro below. It's left here for ease of reading and -// understanding the macro, as well as development. -// ```rust +// This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify +// the below macro. // #[allow(unused_parens, non_snake_case, unused_mut)] // impl HandlerFuture for F // where -// TReq: MessageFull + Send + 'static, -// TRes: MessageFull + Send + 'static, +// TReq: Message + DeserializeOwned + Send + 'static, +// TRes: Message + Serialize + Send + 'static, // Res: RpcIntoResponse, // F: FnOnce(T1, TReq) -> Fut + Clone + Send + 'static, // Fut: Future + Send, @@ -58,7 +55,7 @@ pub trait HandlerFuture: Clone + Send + Sized + 'stati // } // }; -// let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) { +// let proto_req: TReq = match serde_json::from_str(&body) { // Ok(value) => value, // Err(_e) => { // return RpcError::new( @@ -75,7 +72,7 @@ pub trait HandlerFuture: Clone + Send + Sized + 'stati // }) // } // } -// ``` + macro_rules! impl_handler { ( [$($ty:ident),*] @@ -84,8 +81,8 @@ macro_rules! impl_handler { impl HandlerFuture for F where - TReq: MessageFull + Send + 'static, - TRes: MessageFull + Send + 'static, + TReq: Message + DeserializeOwned + Send + 'static, + TRes: Message + Serialize + Send + 'static, Res: RpcIntoResponse, F: FnOnce($($ty,)* TReq) -> Fut + Clone + Send + 'static, Fut: Future + Send, @@ -119,7 +116,7 @@ macro_rules! impl_handler { } }; - let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) { + let proto_req: TReq = match serde_json::from_str(&body) { Ok(value) => value, Err(_e) => { return RpcError::new( diff --git a/axum-connect/src/lib.rs b/axum-connect/src/lib.rs index 3c53c7e..0b5c9ce 100644 --- a/axum-connect/src/lib.rs +++ b/axum-connect/src/lib.rs @@ -1,13 +1,13 @@ -// Re-export protobuf and protobuf_json_mapping for downstream use. -pub use protobuf; -pub use protobuf_json_mapping; - pub mod error; pub mod handler; pub mod parts; pub mod response; pub mod router; +// Re-export both prost and serde. +pub use prost; +pub use serde; + pub mod prelude { pub use crate::error::*; pub use crate::parts::*; diff --git a/axum-connect/src/parts.rs b/axum-connect/src/parts.rs index d25bdda..0d733b3 100644 --- a/axum-connect/src/parts.rs +++ b/axum-connect/src/parts.rs @@ -6,7 +6,7 @@ use axum::{ http::{self}, Extension, }; -use protobuf::MessageFull; +use prost::Message; use serde::de::DeserializeOwned; use crate::error::{RpcError, RpcErrorCode, RpcIntoError}; @@ -14,7 +14,7 @@ use crate::error::{RpcError, RpcErrorCode, RpcIntoError}; #[async_trait] pub trait RpcFromRequestParts: Sized where - T: MessageFull, + T: Message, S: Send + Sync, { /// If the extractor fails it'll use this "rejection" type. A rejection is @@ -31,7 +31,7 @@ where #[async_trait] impl RpcFromRequestParts for Host where - M: MessageFull, + M: Message, S: Send + Sync, { type Rejection = RpcError; @@ -49,7 +49,7 @@ where #[async_trait] impl RpcFromRequestParts for Query where - M: MessageFull, + M: Message, S: Send + Sync, T: DeserializeOwned, { @@ -68,7 +68,7 @@ where #[async_trait] impl RpcFromRequestParts for ConnectInfo where - M: MessageFull, + M: Message, S: Send + Sync, T: Clone + Send + Sync + 'static, { @@ -91,7 +91,7 @@ where #[async_trait] impl RpcFromRequestParts for State where - M: MessageFull, + M: Message, InnerState: FromRef, OuterState: Send + Sync, { diff --git a/axum-connect/src/response.rs b/axum-connect/src/response.rs index 080bd45..1c4ceb7 100644 --- a/axum-connect/src/response.rs +++ b/axum-connect/src/response.rs @@ -1,5 +1,6 @@ use axum::response::{IntoResponse, Response}; -use protobuf::MessageFull; +use prost::Message; +use serde::Serialize; use crate::error::{RpcError, RpcErrorCode, RpcIntoError}; @@ -12,12 +13,12 @@ pub struct RpcResponse { impl IntoResponse for RpcResponse where - T: MessageFull, + T: Message + Serialize, { fn into_response(self) -> Response { let rpc_call_response: Response = { match self.response { - Ok(value) => protobuf_json_mapping::print_to_string(&value) + Ok(value) => serde_json::to_string(&value) .map_err(|_e| { RpcError::new( RpcErrorCode::Internal, @@ -43,14 +44,14 @@ where pub trait RpcIntoResponse: Send + Sync + 'static where - T: MessageFull, + T: Message, { fn rpc_into_response(self) -> RpcResponse; } impl RpcIntoResponse for T where - T: MessageFull, + T: Message + 'static, { fn rpc_into_response(self) -> RpcResponse { RpcResponse { @@ -62,7 +63,7 @@ where impl RpcIntoResponse for Result where - T: MessageFull, + T: Message + 'static, E: RpcIntoError + Send + Sync + 'static, { fn rpc_into_response(self) -> RpcResponse {