Switch to prost and pbjson-build

This commit is contained in:
Alec Thilenius 2023-04-29 08:36:42 -07:00
parent 47b4b09162
commit 4c84deddc9
14 changed files with 157 additions and 234 deletions

View file

@ -2,9 +2,13 @@
"cSpell.words": [
"bufbuild",
"codegen",
"impls",
"pbjson",
"prost",
"proto",
"protobuf",
"protoc",
"protos",
"serde",
"Thilenius",
"typecheck"

View file

@ -12,6 +12,10 @@ repository = "https://github.com/AThilenius/axum-connect"
[dependencies]
anyhow = "1.0"
convert_case = "0.6.0"
protobuf = "3.2.0"
protobuf-codegen = "3.2.0"
protobuf-parse = "3.2.0"
pbjson-build = "0.5.1"
proc-macro2 = "1.0.56"
prost = "0.11.9"
prost-build = "0.11.9"
prost-reflect = "0.11.4"
quote = "1.0.26"
syn = "2.0.15"

View file

@ -0,0 +1,77 @@
use proc_macro2::TokenStream;
use prost_build::{Method, Service, ServiceGenerator};
use quote::{format_ident, quote};
use syn::parse_str;
#[derive(Default)]
pub struct AxumConnectServiceGenerator {}
impl AxumConnectServiceGenerator {
pub fn new() -> Self {
Default::default()
}
fn generate_service(&mut self, service: Service, buf: &mut String) {
// Service struct
let service_name = format_ident!("{}", service.name);
let methods = service.methods.into_iter().map(|m| {
self.generate_service_method(m, &format!("{}.{}", service.package, service.proto_name))
});
buf.push_str(
quote! {
pub struct #service_name;
impl #service_name {
#(#methods)*
}
}
.to_string()
.as_str(),
);
}
fn generate_service_method(&mut self, method: Method, path_root: &str) -> TokenStream {
let method_name = format_ident!("{}", method.name);
let input_type: syn::Type = parse_str(&method.input_type).unwrap();
let output_type: syn::Type = parse_str(&method.output_type).unwrap();
let path = format!("/{}/{}", path_root, method.proto_name);
quote! {
pub fn #method_name<T, H, R, S, B>(
handler: H
) -> impl FnOnce(axum::Router<S, B>) -> axum_connect::router::RpcRouter<S, B>
where
H: axum_connect::handler::HandlerFuture<#input_type, #output_type, R, T, S, B>,
T: 'static,
S: Clone + Send + Sync + 'static,
B: axum::body::HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<axum::BoxError>,
{
use axum::response::IntoResponse;
move |router: axum::Router<S, B>| {
router.route(
#path,
axum::routing::post(|axum::extract::State(state): axum::extract::State<S>, request: axum::http::Request<B>| async move {
let res = handler.call(request, state).await;
res.into_response()
}),
)
}
}
}
}
}
impl ServiceGenerator for AxumConnectServiceGenerator {
fn generate(&mut self, service: Service, buf: &mut String) {
self.generate_service(service, buf);
}
fn finalize(&mut self, buf: &mut String) {
// Add serde import (because that's less effort than hacking pbjson).
buf.push_str("\nuse axum_connect::serde;\n");
}
}

View file

@ -1,205 +1,44 @@
use std::{collections::HashMap, path::Path};
use std::{
env,
io::{BufWriter, Write},
path::{Path, PathBuf},
};
use convert_case::{Case, Casing};
use protobuf::reflect::{FileDescriptor, MessageDescriptor};
use protobuf_parse::Parser;
use gen::AxumConnectServiceGenerator;
mod gen;
// TODO There is certainly a much easier way to do this, but I can't make sense of rust-protobuf.
pub fn axum_connect_codegen(
include: impl AsRef<Path>,
include: &[impl AsRef<Path>],
inputs: &[impl AsRef<Path>],
) -> anyhow::Result<()> {
protobuf_codegen::Codegen::new()
.pure()
.cargo_out_dir("connect_proto_gen")
.inputs(inputs)
.include(&include)
.run()?;
let descriptor_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("proto_descriptor.bin");
let mut parser = Parser::new();
parser.pure();
parser.inputs(inputs);
parser.include(include);
let mut conf = prost_build::Config::new();
conf.compile_well_known_types();
conf.file_descriptor_set_path(&descriptor_path);
conf.extern_path(".google.protobuf", "::pbjson_types");
conf.service_generator(Box::new(AxumConnectServiceGenerator::new()));
conf.compile_protos(inputs, include).unwrap();
let parsed = parser
.parse_and_typecheck()
.expect("parse and typecheck the protobuf files");
// Use pbjson to generate the Serde impls, and inline them with the Prost files.
let descriptor_set = std::fs::read(descriptor_path)?;
let mut output: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
output.push("FILENAME");
let file_descriptors = FileDescriptor::new_dynamic_fds(parsed.file_descriptors.clone(), &[])?;
let writers = pbjson_build::Builder::new()
.register_descriptors(&descriptor_set)?
.generate(&["."], move |package| {
output.set_file_name(format!("{}.rs", package));
// Create a list of relative paths and their corresponding file descriptors.
let relative_paths = parsed
.relative_paths
.iter()
.map(|p| {
(
file_descriptors
.iter()
.find(|&fd| fd.name().ends_with(p.to_str()))
.expect(&format!(
"find a file descriptor matching the relative path {}",
p.to_str()
))
.clone(),
p.to_path()
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_string(),
)
})
.collect();
let file = std::fs::OpenOptions::new().append(true).open(&output)?;
// Flat map the full proto names to their respective rust type name.
let message_names = map_names(relative_paths);
Ok(BufWriter::new(file))
})?;
for path in parsed.relative_paths {
// Find the relative file descriptor
let file_descriptor = parsed
.file_descriptors
.iter()
.find(|&fd| fd.name.clone().unwrap_or_default().ends_with(path.to_str()))
.expect(&format!(
"find a file descriptor matching the relative path {}",
path.to_str()
));
// TODO: This seems fragile.
let path = path.to_path().with_extension("rs");
let cargo_out_dir = std::env::var("OUT_DIR")?;
let out_dir = Path::new(&cargo_out_dir).join("connect_proto_gen");
let proto_rs_file_name = path.file_name().unwrap().to_str().unwrap();
let proto_rs_full_path = out_dir.join(&proto_rs_file_name);
// Replace all instances of "::protobuf::" with "::axum_connect::protobuf::" in the original
// generated file.
let rust = std::fs::read_to_string(&proto_rs_full_path)?;
let rust = rust.replace("::protobuf::", "::axum_connect::protobuf::");
// std::fs::write(&proto_rs_full_path, rust)?;
// Build up the service implementation file source.
let mut c = String::new();
c.push_str(FILE_PREAMBLE_TEMPLATE);
for service in &file_descriptor.service {
// Build up methods first
let mut m = String::new();
for method in &service.method {
let input_type = message_names.get(method.input_type()).unwrap();
let output_type = message_names.get(method.output_type()).unwrap();
m.push_str(
&METHOD_TEMPLATE
.replace("@@METHOD_NAME@@", &method.name().to_case(Case::Snake))
.replace("@@INPUT_TYPE@@", &input_type)
.replace("@@OUTPUT_TYPE@@", &output_type)
.replace(
"@@ROUTE@@",
&format!(
"/{}.{}/{}",
file_descriptor.package(),
service.name(),
method.name()
),
),
);
}
c.push_str(
&SERVICE_TEMPLATE
.replace("@@SERVICE_NAME@@", service.name())
.replace("@@SERVICE_METHODS@@", &m),
);
}
let mut final_file = String::new();
final_file.push_str(&rust);
final_file.push_str(&c);
std::fs::write(&proto_rs_full_path, &final_file)?;
for (_, mut writer) in writers {
writer.flush()?;
}
Ok(())
}
const FILE_PREAMBLE_TEMPLATE: &str = "// Generated by axum-connect-build
use axum::{
body::HttpBody, extract::State, http::Request, response::IntoResponse, routing::post, BoxError,
Router,
};
use axum_connect::{handler::HandlerFuture, router::RpcRouter};
";
const SERVICE_TEMPLATE: &str = "
pub struct @@SERVICE_NAME@@;
impl @@SERVICE_NAME@@ {
@@SERVICE_METHODS@@
}";
const METHOD_TEMPLATE: &str = "
pub fn @@METHOD_NAME@@<T, H, R, S, B>(handler: H) -> impl FnOnce(Router<S, B>) -> RpcRouter<S, B>
where
H: HandlerFuture<@@INPUT_TYPE@@, @@OUTPUT_TYPE@@, R, T, S, B>,
T: 'static,
S: Clone + Send + Sync + 'static,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
move |router: Router<S, B>| {
router.route(
\"@@ROUTE@@\",
post(|State(state): State<S>, request: Request<B>| async move {
let res = handler.call(request, state).await;
res.into_response()
}),
)
}
}
";
/// Takes a vec to FileDescriptor and returns a flattened map of the full name for each message type
/// to the associated Rust name relative to `super`.
fn map_names(file_descriptors: Vec<(FileDescriptor, String)>) -> HashMap<String, String> {
let mut map = HashMap::new();
for (file_descriptor, path) in file_descriptors {
for message in file_descriptor.messages() {
collect_messages_recursive(&mut map, message, format!("super::{}", path));
}
}
map
}
fn collect_messages_recursive(
map: &mut HashMap<String, String>,
message: MessageDescriptor,
parent_rust_path: String,
) {
map.insert(
format!(".{}", message.full_name()),
format!(
"{}::{}",
parent_rust_path,
message.name().to_case(Case::UpperCamel)
),
);
for nested in message.nested_messages() {
collect_messages_recursive(
map,
nested,
format!(
"{}::{}",
parent_rust_path,
message.name().to_case(Case::Snake)
),
);
}
}

View file

@ -1,11 +1,12 @@
[package]
name = "hello_world"
name = "axum-connect-example"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = "0.6.9"
axum-connect = { path = "../axum-connect" }
prost = "0.11.9"
tokio = { version = "1.0", features = ["full"] }
[build-dependencies]

View file

@ -1,5 +1,5 @@
use axum_connect_build::axum_connect_codegen;
fn main() {
axum_connect_codegen("proto", &["proto/hello.proto"]).unwrap();
axum_connect_codegen(&["proto"], &["proto/hello.proto"]).unwrap();
}

View file

@ -1,6 +1,6 @@
syntax = "proto3";
package blink.hello;
package hello;
message HelloRequest { string name = 1; }

View file

@ -2,10 +2,12 @@ use std::net::SocketAddr;
use axum::{extract::Host, Router};
use axum_connect::prelude::*;
use proto::hello::{HelloRequest, HelloResponse, HelloWorldService};
use proto::hello::*;
mod proto {
include!(concat!(env!("OUT_DIR"), "/connect_proto_gen/mod.rs"));
pub mod hello {
include!(concat!(env!("OUT_DIR"), "/hello.rs"));
}
}
#[tokio::main]
@ -30,6 +32,5 @@ async fn say_hello_success(Host(host): Host, request: HelloRequest) -> HelloResp
"Hello {}! You're addressing the hostname: {}.",
request.name, host
),
special_fields: Default::default(),
}
}

View file

@ -17,7 +17,6 @@ repository = "https://github.com/AThilenius/axum-connect"
async-trait = "0.1.64"
axum = "0.6.9"
futures = "0.3.26"
protobuf = "3.2.0"
protobuf-json-mapping = "3.2.0"
prost = "0.11.9"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View file

@ -2,7 +2,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use protobuf::MessageFull;
use prost::Message;
use serde::Serialize;
use crate::{prelude::RpcResponse, response::RpcIntoResponse};
@ -103,7 +103,7 @@ impl From<RpcErrorCode> for StatusCode {
impl<T> RpcIntoResponse<T> for RpcErrorCode
where
T: MessageFull,
T: Message,
{
fn rpc_into_response(self) -> RpcResponse<T> {
RpcResponse {
@ -115,7 +115,7 @@ where
impl<T> RpcIntoResponse<T> for RpcError
where
T: MessageFull,
T: Message,
{
fn rpc_into_response(self) -> RpcResponse<T> {
RpcResponse {

View file

@ -2,10 +2,8 @@ use std::pin::Pin;
use axum::{body::HttpBody, extract::FromRequest, http::Request, BoxError};
use futures::Future;
use protobuf::MessageFull;
pub use protobuf;
pub use protobuf_json_mapping;
use prost::Message;
use serde::{de::DeserializeOwned, Serialize};
pub use crate::{error::RpcIntoError, parts::RpcFromRequestParts, response::RpcIntoResponse};
use crate::{
@ -19,14 +17,13 @@ pub trait HandlerFuture<TReq, TRes, Res, T, S, B>: Clone + Send + Sized + 'stati
fn call(self, req: Request<B>, state: S) -> Self::Future;
}
// This is a single expanded version of the macro below. It's left here for ease of reading and
// understanding the macro, as well as development.
// ```rust
// This is here because writing Rust macros sucks a**. So I uncomment this when I'm trying to modify
// the below macro.
// #[allow(unused_parens, non_snake_case, unused_mut)]
// impl<TReq, TRes, Res, F, Fut, S, B, T1> HandlerFuture<TReq, TRes, Res, (T1, TReq), S, B> for F
// where
// TReq: MessageFull + Send + 'static,
// TRes: MessageFull + Send + 'static,
// TReq: Message + DeserializeOwned + Send + 'static,
// TRes: Message + Serialize + Send + 'static,
// Res: RpcIntoResponse<TRes>,
// F: FnOnce(T1, TReq) -> Fut + Clone + Send + 'static,
// Fut: Future<Output = Res> + Send,
@ -58,7 +55,7 @@ pub trait HandlerFuture<TReq, TRes, Res, T, S, B>: Clone + Send + Sized + 'stati
// }
// };
// let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) {
// let proto_req: TReq = match serde_json::from_str(&body) {
// Ok(value) => value,
// Err(_e) => {
// return RpcError::new(
@ -75,7 +72,7 @@ pub trait HandlerFuture<TReq, TRes, Res, T, S, B>: Clone + Send + Sized + 'stati
// })
// }
// }
// ```
macro_rules! impl_handler {
(
[$($ty:ident),*]
@ -84,8 +81,8 @@ macro_rules! impl_handler {
impl<TReq, TRes, Res, F, Fut, S, B, $($ty,)*>
HandlerFuture<TReq, TRes, Res, ($($ty,)* TReq), S, B> for F
where
TReq: MessageFull + Send + 'static,
TRes: MessageFull + Send + 'static,
TReq: Message + DeserializeOwned + Send + 'static,
TRes: Message + Serialize + Send + 'static,
Res: RpcIntoResponse<TRes>,
F: FnOnce($($ty,)* TReq) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
@ -119,7 +116,7 @@ macro_rules! impl_handler {
}
};
let proto_req: TReq = match protobuf_json_mapping::parse_from_str(&body) {
let proto_req: TReq = match serde_json::from_str(&body) {
Ok(value) => value,
Err(_e) => {
return RpcError::new(

View file

@ -1,13 +1,13 @@
// Re-export protobuf and protobuf_json_mapping for downstream use.
pub use protobuf;
pub use protobuf_json_mapping;
pub mod error;
pub mod handler;
pub mod parts;
pub mod response;
pub mod router;
// Re-export both prost and serde.
pub use prost;
pub use serde;
pub mod prelude {
pub use crate::error::*;
pub use crate::parts::*;

View file

@ -6,7 +6,7 @@ use axum::{
http::{self},
Extension,
};
use protobuf::MessageFull;
use prost::Message;
use serde::de::DeserializeOwned;
use crate::error::{RpcError, RpcErrorCode, RpcIntoError};
@ -14,7 +14,7 @@ use crate::error::{RpcError, RpcErrorCode, RpcIntoError};
#[async_trait]
pub trait RpcFromRequestParts<T, S>: Sized
where
T: MessageFull,
T: Message,
S: Send + Sync,
{
/// If the extractor fails it'll use this "rejection" type. A rejection is
@ -31,7 +31,7 @@ where
#[async_trait]
impl<M, S> RpcFromRequestParts<M, S> for Host
where
M: MessageFull,
M: Message,
S: Send + Sync,
{
type Rejection = RpcError;
@ -49,7 +49,7 @@ where
#[async_trait]
impl<M, S, T> RpcFromRequestParts<M, S> for Query<T>
where
M: MessageFull,
M: Message,
S: Send + Sync,
T: DeserializeOwned,
{
@ -68,7 +68,7 @@ where
#[async_trait]
impl<M, S, T> RpcFromRequestParts<M, S> for ConnectInfo<T>
where
M: MessageFull,
M: Message,
S: Send + Sync,
T: Clone + Send + Sync + 'static,
{
@ -91,7 +91,7 @@ where
#[async_trait]
impl<M, OuterState, InnerState> RpcFromRequestParts<M, OuterState> for State<InnerState>
where
M: MessageFull,
M: Message,
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{

View file

@ -1,5 +1,6 @@
use axum::response::{IntoResponse, Response};
use protobuf::MessageFull;
use prost::Message;
use serde::Serialize;
use crate::error::{RpcError, RpcErrorCode, RpcIntoError};
@ -12,12 +13,12 @@ pub struct RpcResponse<T> {
impl<T> IntoResponse for RpcResponse<T>
where
T: MessageFull,
T: Message + Serialize,
{
fn into_response(self) -> Response {
let rpc_call_response: Response = {
match self.response {
Ok(value) => protobuf_json_mapping::print_to_string(&value)
Ok(value) => serde_json::to_string(&value)
.map_err(|_e| {
RpcError::new(
RpcErrorCode::Internal,
@ -43,14 +44,14 @@ where
pub trait RpcIntoResponse<T>: Send + Sync + 'static
where
T: MessageFull,
T: Message,
{
fn rpc_into_response(self) -> RpcResponse<T>;
}
impl<T> RpcIntoResponse<T> for T
where
T: MessageFull,
T: Message + 'static,
{
fn rpc_into_response(self) -> RpcResponse<T> {
RpcResponse {
@ -62,7 +63,7 @@ where
impl<T, E> RpcIntoResponse<T> for Result<T, E>
where
T: MessageFull,
T: Message + 'static,
E: RpcIntoError + Send + Sync + 'static,
{
fn rpc_into_response(self) -> RpcResponse<T> {