diff --git a/Cargo.toml b/Cargo.toml index 0375cddd19..6d79d4445e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ members = [ "qcow_utils", "resources", "rutabaga_gfx", + "serde_keyvalue", "system_api_stub", "tpm2", "tpm2-sys", @@ -172,6 +173,7 @@ remain = "*" resources = { path = "resources" } scudo = { version = "0.1", optional = true } serde_json = "*" +serde_keyvalue = { path = "serde_keyvalue" } sync = { path = "common/sync" } tempfile = "3" thiserror = { version = "1.0.20" } diff --git a/serde_keyvalue/Cargo.toml b/serde_keyvalue/Cargo.toml new file mode 100644 index 0000000000..838a368994 --- /dev/null +++ b/serde_keyvalue/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "serde_keyvalue" +version = "0.1.0" +authors = ["The Chromium OS Authors"] +edition = "2021" + +[features] +argh_derive = ["argh", "serde_keyvalue_derive"] + +[dependencies] +argh = { version = "0.1.7", optional = true } +serde_keyvalue_derive = { path = "serde_keyvalue_derive", optional = true } +serde = "1" +thiserror = { version = "1.0.20" } +remain = "*" + +[dev-dependencies] +serde = { version = "1", features = ["derive"] } \ No newline at end of file diff --git a/serde_keyvalue/README.md b/serde_keyvalue/README.md new file mode 100644 index 0000000000..30498cf020 --- /dev/null +++ b/serde_keyvalue/README.md @@ -0,0 +1,19 @@ +# Serde deserializer from key=value strings + +A lightweight serde deserializer for strings containing key-value pairs separated by commas, as +commonly found in command-line parameters. + +Say your program takes a command-line option of the form: + +```text +--foo type=bar,active,nb_threads=8 +``` + +This crate provides a `from_key_values` function that deserializes these key-values into a +configuration structure. Since it uses serde, the same configuration structure can also be created +from any other supported source (such as a TOML or YAML configuration file) that uses the same keys. + +Integration with the [argh](https://github.com/google/argh) command-line parser is also provided via +the `argh_derive` feature. + +See the inline documentation for examples and more details. diff --git a/serde_keyvalue/serde_keyvalue_derive/Cargo.toml b/serde_keyvalue/serde_keyvalue_derive/Cargo.toml new file mode 100644 index 0000000000..1ece8456d2 --- /dev/null +++ b/serde_keyvalue/serde_keyvalue_derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "serde_keyvalue_derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +argh = "0.1.7" +proc-macro2 = "1.0" +syn = "1.0" +quote = "1.0" diff --git a/serde_keyvalue/serde_keyvalue_derive/src/lib.rs b/serde_keyvalue/serde_keyvalue_derive/src/lib.rs new file mode 100644 index 0000000000..dfb605250d --- /dev/null +++ b/serde_keyvalue/serde_keyvalue_derive/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use quote::quote; +use syn::{parse_macro_input, DeriveInput}; + +/// Implement `argh`'s `FromArgValue` trait for a struct or enum using `from_key_values`. +#[proc_macro_derive(FromKeyValues)] +pub fn keyvalues_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let DeriveInput { + ident, generics, .. + } = parse_macro_input!(input); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + quote! { + impl #impl_generics ::serde_keyvalue::argh::FromArgValue for #ident #ty_generics #where_clause { + fn from_arg_value(value: &str) -> std::result::Result { + ::serde_keyvalue::from_key_values(value).map_err(|e| e.to_string()) + } + } + } + .into() +} diff --git a/serde_keyvalue/src/key_values.rs b/serde_keyvalue/src/key_values.rs new file mode 100644 index 0000000000..415a66be71 --- /dev/null +++ b/serde_keyvalue/src/key_values.rs @@ -0,0 +1,1048 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::borrow::Cow; +use std::fmt::{self, Debug, Display}; +use std::num::{IntErrorKind, ParseIntError}; +use std::str::FromStr; + +use remain::sorted; +use serde::de; +use serde::Deserialize; +use thiserror::Error; + +#[derive(Debug, Error, PartialEq)] +#[sorted] +#[non_exhaustive] +#[allow(missing_docs)] +/// Different kinds of errors that can be returned by the parser. +pub enum ErrorKind { + #[error("unexpected end of input")] + Eof, + #[error("expected a boolean")] + ExpectedBoolean, + #[error("expected ','")] + ExpectedComma, + #[error("expected '='")] + ExpectedEqual, + #[error("expected an identifier")] + ExpectedIdentifier, + #[error("expected a number")] + ExpectedNumber, + #[error("expected a string")] + ExpectedString, + #[error("\" and ' can only be used in quoted strings")] + InvalidCharInString, + #[error("non-terminated string")] + NonTerminatedString, + #[error("provided number does not fit in the destination type")] + NumberOverflow, + #[error("serde error: {0}")] + SerdeError(String), + #[error("remaining characters in input")] + TrailingCharacters, +} + +/// Error that may be thown while parsing a key-values string. +#[derive(Debug, Error, PartialEq)] +pub struct ParseError { + /// Detailed error that occurred. + pub kind: ErrorKind, + /// Index of the error in the input string. + pub pos: usize, +} + +impl Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::SerdeError(s) => write!(f, "{}", s), + _ => write!(f, "{} at position {}", self.kind, self.pos), + } + } +} + +impl de::Error for ParseError { + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + Self { + kind: ErrorKind::SerdeError(msg.to_string()), + pos: 0, + } + } +} + +type Result = std::result::Result; + +/// Serde deserializer for key-values strings. +struct KeyValueDeserializer<'de> { + /// Full input originally received for parsing. + original_input: &'de str, + /// Input currently remaining to parse. + input: &'de str, + /// If set, then `deserialize_identifier` will take and return its content the next time it is + /// called instead of trying to parse an identifier from the input. This is needed to allow the + /// name of the first field of a struct to be omitted, e.g. + /// + /// --block "/path/to/disk.img,ro=true" + /// + /// instead of + /// + /// --block "path=/path/to/disk.img,ro=true" + next_identifier: Option<&'de str>, +} + +impl<'de> From<&'de str> for KeyValueDeserializer<'de> { + fn from(input: &'de str) -> Self { + Self { + original_input: input, + input, + next_identifier: None, + } + } +} + +impl<'de> KeyValueDeserializer<'de> { + /// Return an `kind` error for the current position of the input. + fn error_here(&self, kind: ErrorKind) -> ParseError { + ParseError { + kind, + pos: self.original_input.len() - self.input.len(), + } + } + + /// Returns the next char in the input string without consuming it, or None + /// if we reached the end of input. + fn peek_char(&self) -> Option { + self.input.chars().next() + } + + /// Skip the next char in the input string. + fn skip_char(&mut self) { + let _ = self.next_char(); + } + + /// Returns the next char in the input string and consume it, or returns + /// None if we reached the end of input. + fn next_char(&mut self) -> Option { + let c = self.peek_char()?; + self.input = &self.input[c.len_utf8()..]; + Some(c) + } + + /// Try to peek the next element in the input as an identifier, without consuming it. + /// + /// Returns the parsed indentifier, an `ExpectedIdentifier` error if the next element is not + /// an identifier, or `Eof` if we were at the end of the input string. + fn peek_identifier(&self) -> Result<&'de str> { + // End of input? + if self.input.is_empty() { + return Err(self.error_here(ErrorKind::Eof)); + } + + let res = self.input; + let mut len = 0; + let mut iter = self.input.chars(); + loop { + match iter.next() { + None | Some(',' | '=') => break, + Some(c) if c.is_ascii_alphanumeric() || c == '_' || (c == '-' && len > 0) => { + len += c.len_utf8(); + } + Some(_) => return Err(self.error_here(ErrorKind::ExpectedIdentifier)), + } + } + + // An identifier cannot be empty. + if len == 0 { + Err(self.error_here(ErrorKind::ExpectedIdentifier)) + } else { + Ok(&res[0..len]) + } + } + + /// Peek the next value, i.e. anything until the next comma or the end of the input string. + /// + /// This can be used to reliably peek any value, except strings which may contain commas in + /// quotes. + fn peek_value(&self) -> Result<&'de str> { + let res = self.input; + let mut len = 0; + let mut iter = self.input.chars(); + loop { + match iter.next() { + None | Some(',') => break, + Some(c) => len += c.len_utf8(), + } + } + + if len > 0 { + Ok(&res[0..len]) + } else { + Err(self.error_here(ErrorKind::Eof)) + } + } + + /// Attempts to parse an identifier, either for a key or for the value of an enum type. + /// + /// Usually identifiers are not allowed to start with a number, but we chose to allow this + /// here otherwise options like "mode=2d" won't parse if "2d" is an alias for an enum variant. + fn parse_identifier(&mut self) -> Result<&'de str> { + let res = self.peek_identifier()?; + self.input = &self.input[res.len()..]; + Ok(res) + } + + /// Attempts to parse a string. + /// + /// A string can be quoted (using single or double quotes) or not. If it is not, we consume + /// input until the next ',' separating character. If it is, we consume input until the next + /// non-escaped quote. + /// + /// The returned value is a slice into the current input if no characters to unescape were met, + /// or a fully owned string if we had to unescape some characters. + fn parse_string(&mut self) -> Result> { + let (s, quote) = match self.peek_char() { + // Beginning of quoted string. + quote @ Some('"' | '\'') => { + // Safe because we just matched against `Some`. + let quote = quote.unwrap(); + // Skip the opening quote. + self.skip_char(); + let mut len = 0; + let mut iter = self.input.chars(); + let mut escaped = false; + loop { + let c = match iter.next() { + Some('\\') if !escaped => { + escaped = true; + '\\' + } + // Found end of quoted string if we meet a non-escaped quote. + Some(c) if c == quote && !escaped => break, + Some(c) => { + escaped = false; + c + } + None => return Err(self.error_here(ErrorKind::NonTerminatedString)), + }; + len += c.len_utf8(); + } + let s = &self.input[0..len]; + self.input = &self.input[len..]; + // Skip the closing quote + self.skip_char(); + (s, Some(quote)) + } + // Empty strings must use quotes. + None | Some(',') => return Err(self.error_here(ErrorKind::ExpectedString)), + // Non-quoted string. + Some(_) => { + let s = self + .input + .split(&[',', '"', '\'']) + .next() + .unwrap_or(self.input); + self.input = &self.input[s.len()..]; + // If a string was not quoted, it shall not contain a quote. + if let Some('"' | '\'') = self.peek_char() { + return Err(self.error_here(ErrorKind::InvalidCharInString)); + } + (s, None) + } + }; + + if quote.is_some() { + let mut escaped = false; + let unescaped_string: String = s + .chars() + .filter_map(|c| match c { + '\\' if !escaped => { + escaped = true; + None + } + c => { + escaped = false; + Some(c) + } + }) + .collect(); + Ok(Cow::Owned(unescaped_string)) + } else { + Ok(Cow::Borrowed(s)) + } + } + + /// A boolean can be 'true', 'false', or nothing (which is equivalent to 'true'). + fn parse_bool(&mut self) -> Result { + // 'true' and 'false' can be picked by peek_value. + let s = match self.peek_value() { + Ok(s) => s, + // Consider end of input as an empty string, which will be evaluated to `true`. + Err(ParseError { + kind: ErrorKind::Eof, + .. + }) => "", + Err(_) => return Err(self.error_here(ErrorKind::ExpectedBoolean)), + }; + let res = match s { + "" => Ok(true), + s => bool::from_str(s).map_err(|_| self.error_here(ErrorKind::ExpectedBoolean)), + }; + + self.input = &self.input[s.len()..]; + + res + } + + /// Parse a positive or negative number. + // TODO support 0x or 0b notation? + fn parse_number(&mut self) -> Result + where + T: FromStr, + { + let num_str = self.peek_value()?; + let val = T::from_str(num_str).map_err(|e| { + self.error_here( + if let IntErrorKind::PosOverflow | IntErrorKind::NegOverflow = e.kind() { + ErrorKind::NumberOverflow + } else { + ErrorKind::ExpectedNumber + }, + ) + })?; + self.input = &self.input[num_str.len()..]; + Ok(val) + } +} + +impl<'de> de::MapAccess<'de> for KeyValueDeserializer<'de> { + type Error = ParseError; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: de::DeserializeSeed<'de>, + { + let has_next_identifier = self.next_identifier.is_some(); + + if self.peek_char().is_none() { + return Ok(None); + } + let val = seed.deserialize(&mut *self).map(Some)?; + + // We just "deserialized" the content of `next_identifier`, so there should be no equal + // character in the input. We can return now. + if has_next_identifier { + return Ok(val); + } + + match self.peek_char() { + // We expect an equal after an identifier. + Some('=') => { + self.skip_char(); + Ok(val) + } + // Ok if we are parsing a boolean where an empty value means true. + Some(',') | None => Ok(val), + Some(_) => Err(self.error_here(ErrorKind::ExpectedEqual)), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: de::DeserializeSeed<'de>, + { + let val = seed.deserialize(&mut *self)?; + + // We must have a comma or end of input after a value. + match self.next_char() { + Some(',') | None => Ok(val), + Some(_) => Err(self.error_here(ErrorKind::ExpectedComma)), + } + } +} + +struct Enum<'a, 'de: 'a>(&'a mut KeyValueDeserializer<'de>); + +impl<'a, 'de> Enum<'a, 'de> { + fn new(de: &'a mut KeyValueDeserializer<'de>) -> Self { + Self(de) + } +} + +impl<'a, 'de> de::EnumAccess<'de> for Enum<'a, 'de> { + type Error = ParseError; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: de::DeserializeSeed<'de>, + { + let val = seed.deserialize(&mut *self.0)?; + Ok((val, self)) + } +} + +impl<'a, 'de> de::VariantAccess<'de> for Enum<'a, 'de> { + type Error = ParseError; + + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + unimplemented!() + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unimplemented!() + } + + fn struct_variant(self, _fields: &'static [&'static str], _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + unimplemented!() + } +} + +impl<'de, 'a> de::Deserializer<'de> for &'a mut KeyValueDeserializer<'de> { + type Error = ParseError; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.peek_char() { + Some('0'..='9') => self.deserialize_u64(visitor), + Some('-') => self.deserialize_i64(visitor), + Some('"') => self.deserialize_string(visitor), + // Only possible option here is boolean flag. + Some(',') | None => self.deserialize_bool(visitor), + _ => { + // We probably have an unquoted string, but possibly a boolean as well. + match self.peek_identifier() { + Ok("true") | Ok("false") => self.deserialize_bool(visitor), + _ => self.deserialize_str(visitor), + } + } + } + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_bool(self.parse_bool()?) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_i8(self.parse_number()?) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_i16(self.parse_number()?) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_i32(self.parse_number()?) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_i64(self.parse_number()?) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_u8(self.parse_number()?) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_u16(self.parse_number()?) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_u32(self.parse_number()?) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_u64(self.parse_number()?) + } + + fn deserialize_f32(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_f64(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_char(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_char( + self.next_char() + .ok_or_else(|| self.error_here(ErrorKind::Eof))?, + ) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.parse_string()? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_string(s), + } + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_str(visitor) + } + + fn deserialize_bytes(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_bytes(visitor) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + // The fact that an option is specified implies that is exists, hence we always visit + // Some() here. + visitor.visit_some(self) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_unit(visitor) + } + + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_tuple(self, _len: usize, _visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + unimplemented!() + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_map(self) + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + // The name of the first field of a struct can be omitted (see documentation of + // `next_identifier` for details). + // + // To detect this, peek the next identifier, and check if the character following is '='. If + // it is not, then we may have a value in first position, unless the value is identical to + // one of the field's name - in this case, assume this is a boolean using the flag syntax. + self.next_identifier = match self.peek_identifier() { + Ok(s) => match self.input.chars().nth(s.chars().count()) { + Some('=') => None, + _ => { + if fields.contains(&s) { + None + } else { + fields.get(0).copied() + } + } + }, + // Not an identifier, probably means this is a value for the first field then. + Err(_) => fields.get(0).copied(), + }; + visitor.visit_map(self) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_enum(Enum::new(self)) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let identifier = self + .next_identifier + .take() + .map_or_else(|| self.parse_identifier(), Ok)?; + + visitor.visit_borrowed_str(identifier) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_any(visitor) + } +} + +/// Attempts to deserialize `T` from the key-values string `input`. +pub fn from_key_values<'a, T>(input: &'a str) -> Result +where + T: Deserialize<'a>, +{ + let mut deserializer = KeyValueDeserializer::from(input); + let t = T::deserialize(&mut deserializer)?; + if deserializer.input.is_empty() { + Ok(t) + } else { + Err(deserializer.error_here(ErrorKind::TrailingCharacters)) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + + #[derive(Deserialize, PartialEq, Debug)] + struct SingleStruct { + m: T, + } + + #[test] + fn deserialize_number() { + let res = from_key_values::>("m=54").unwrap(); + assert_eq!(res.m, 54); + + let res = from_key_values::>("m=-54").unwrap(); + assert_eq!(res.m, -54); + + // Parsing a signed into an unsigned? + let res = from_key_values::>("m=-54").unwrap_err(); + assert_eq!( + res, + ParseError { + kind: ErrorKind::ExpectedNumber, + pos: 2 + } + ); + + // Value too big for a signed? + let val = i32::MAX as u32 + 1; + let res = from_key_values::>(&format!("m={}", val)).unwrap_err(); + assert_eq!( + res, + ParseError { + kind: ErrorKind::NumberOverflow, + pos: 2 + } + ); + + let res = from_key_values::>("m=test").unwrap_err(); + assert_eq!( + res, + ParseError { + kind: ErrorKind::ExpectedNumber, + pos: 2, + } + ); + } + + #[test] + fn deserialize_string() { + let kv = "m=John"; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "John".to_string()); + + // Spaces are valid (but not recommended) in unquoted strings. + let kv = "m=John Doe"; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "John Doe".to_string()); + + // Empty string is not valid if unquoted + let kv = "m="; + let err = from_key_values::>(kv).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: ErrorKind::ExpectedString, + pos: 2 + } + ); + + // Quoted strings. + let kv = r#"m="John Doe""#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "John Doe".to_string()); + let kv = r#"m='John Doe'"#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "John Doe".to_string()); + + // Empty quoted strings. + let kv = r#"m="""#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "".to_string()); + let kv = r#"m=''"#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "".to_string()); + + // "=", "," and "'"" in quote. + let kv = r#"m="val = [10, 20, 'a']""#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, r#"val = [10, 20, 'a']"#.to_string()); + + // Quotes in unquoted strings are forbidden. + let kv = r#"m=val="a""#; + let err = from_key_values::>(kv).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: ErrorKind::InvalidCharInString, + pos: 6 + } + ); + let kv = r#"m=val='a'"#; + let err = from_key_values::>(kv).unwrap_err(); + assert_eq!( + err, + ParseError { + kind: ErrorKind::InvalidCharInString, + pos: 6 + } + ); + + // Numbers and booleans are technically valid strings. + let kv = "m=10"; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "10".to_string()); + let kv = "m=false"; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, "false".to_string()); + + // Escaped quote. + let kv = r#"m="Escaped \" quote""#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, r#"Escaped " quote"#.to_string()); + + // Escaped slash at end of string. + let kv = r#"m="Escaped slash\\""#; + let res = from_key_values::>(kv).unwrap(); + assert_eq!(res.m, r#"Escaped slash\"#.to_string()); + } + + #[test] + fn deserialize_bool() { + let res = from_key_values::>("m=true").unwrap(); + assert_eq!(res.m, true); + + let res = from_key_values::>("m=false").unwrap(); + assert_eq!(res.m, false); + + let res = from_key_values::>("m").unwrap(); + assert_eq!(res.m, true); + + let res = from_key_values::>("m=10").unwrap_err(); + assert_eq!( + res, + ParseError { + kind: ErrorKind::ExpectedBoolean, + pos: 2, + } + ); + } + + #[test] + fn deserialize_complex_struct() { + #[derive(Deserialize, PartialEq, Debug)] + struct TestStruct { + num: usize, + path: PathBuf, + enable: bool, + } + let kv = "num=54,path=/dev/foomatic,enable=false"; + let res = from_key_values::(kv).unwrap(); + assert_eq!( + res, + TestStruct { + num: 54, + path: "/dev/foomatic".into(), + enable: false, + } + ); + + let kv = "enable,path=/usr/lib/libossom.so.1,num=12"; + let res = from_key_values::(kv).unwrap(); + assert_eq!( + res, + TestStruct { + num: 12, + path: "/usr/lib/libossom.so.1".into(), + enable: true, + } + ); + } + + #[test] + fn deserialize_unknown_field() { + #[derive(Deserialize, PartialEq, Debug)] + #[serde(deny_unknown_fields)] + struct TestStruct { + num: usize, + path: PathBuf, + enable: bool, + } + + let kv = "enable,path=/usr/lib/libossom.so.1,num=12,foo=bar"; + assert!(from_key_values::(kv).is_err()); + } + + #[test] + fn deserialize_option() { + #[derive(Deserialize, PartialEq, Debug)] + struct TestStruct { + num: u32, + opt: Option, + } + let kv = "num=16,opt=12"; + let res: TestStruct = from_key_values(kv).unwrap(); + assert_eq!( + res, + TestStruct { + num: 16, + opt: Some(12), + } + ); + + let kv = "num=16"; + let res: TestStruct = from_key_values(kv).unwrap(); + assert_eq!(res, TestStruct { num: 16, opt: None }); + + let kv = ""; + assert!(from_key_values::(kv).is_err()); + } + + #[test] + fn deserialize_enum() { + #[derive(Deserialize, PartialEq, Debug)] + enum TestEnum { + #[serde(rename = "first")] + FirstVariant, + #[serde(rename = "second")] + SecondVariant, + } + let res: TestEnum = from_key_values("first").unwrap(); + assert_eq!(res, TestEnum::FirstVariant,); + + let res: TestEnum = from_key_values("second").unwrap(); + assert_eq!(res, TestEnum::SecondVariant,); + + from_key_values::("third").unwrap_err(); + } + + #[test] + fn deserialize_embedded_enum() { + #[derive(Deserialize, PartialEq, Debug)] + enum TestEnum { + #[serde(rename = "first")] + FirstVariant, + #[serde(rename = "second")] + SecondVariant, + } + #[derive(Deserialize, PartialEq, Debug)] + struct TestStruct { + variant: TestEnum, + #[serde(default)] + active: bool, + } + let res: TestStruct = from_key_values("variant=first").unwrap(); + assert_eq!( + res, + TestStruct { + variant: TestEnum::FirstVariant, + active: false, + } + ); + let res: TestStruct = from_key_values("variant=second,active=true").unwrap(); + assert_eq!( + res, + TestStruct { + variant: TestEnum::SecondVariant, + active: true, + } + ); + let res: TestStruct = from_key_values("active=true,variant=second").unwrap(); + assert_eq!( + res, + TestStruct { + variant: TestEnum::SecondVariant, + active: true, + } + ); + let res: TestStruct = from_key_values("active,variant=second").unwrap(); + assert_eq!( + res, + TestStruct { + variant: TestEnum::SecondVariant, + active: true, + } + ); + let res: TestStruct = from_key_values("active=false,variant=second").unwrap(); + assert_eq!( + res, + TestStruct { + variant: TestEnum::SecondVariant, + active: false, + } + ); + } + + #[test] + fn deserialize_first_arg_string() { + #[derive(Deserialize, PartialEq, Debug)] + struct TestStruct { + name: String, + num: u8, + } + let res: TestStruct = from_key_values("name=foo,num=12").unwrap(); + assert_eq!( + res, + TestStruct { + name: "foo".into(), + num: 12, + } + ); + + let res: TestStruct = from_key_values("foo,num=12").unwrap(); + assert_eq!( + res, + TestStruct { + name: "foo".into(), + num: 12, + } + ); + } + + #[test] + fn deserialize_first_arg_int() { + #[derive(Deserialize, PartialEq, Debug)] + struct TestStruct { + num: u8, + name: String, + } + let res: TestStruct = from_key_values("name=foo,num=12").unwrap(); + assert_eq!( + res, + TestStruct { + num: 12, + name: "foo".into(), + } + ); + + let res: TestStruct = from_key_values("12,name=foo").unwrap(); + assert_eq!( + res, + TestStruct { + num: 12, + name: "foo".into(), + } + ); + } +} diff --git a/serde_keyvalue/src/lib.rs b/serde_keyvalue/src/lib.rs new file mode 100644 index 0000000000..51d9d37f2e --- /dev/null +++ b/serde_keyvalue/src/lib.rs @@ -0,0 +1,295 @@ +// Copyright 2022 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! A lightweight serde deserializer for strings containing key-value pairs separated by commas, as +//! commonly found in command-line parameters. +//! +//! Say your program takes a command-line option of the form: +//! +//! ```text +//! --foo type=bar,active,nb_threads=8 +//! ``` +//! +//! This crate provides a [from_key_values] function that deserializes these key-values into a +//! configuration structure. Since it uses serde, the same configuration structure can also be +//! created from any other supported source (such as a TOML or YAML configuration file) that uses +//! the same keys. +//! +//! Integration with the [argh](https://github.com/google/argh) command-line parser is also +//! provided via the `argh_derive` feature. +//! +//! The deserializer supports parsing signed and unsigned integers, booleans, strings (quoted or +//! not), paths, and enums inside a top-level struct. The order in which the fields appear in the +//! string is not important. +//! +//! Simple example: +//! +//! ``` +//! use serde_keyvalue::from_key_values; +//! use serde::Deserialize; +//! +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! path: String, +//! threads: u8, +//! active: bool, +//! } +//! +//! let config: Config = from_key_values("path=/some/path,threads=16,active=true").unwrap(); +//! assert_eq!(config, Config { path: "/some/path".into(), threads: 16, active: true }); +//! +//! let config: Config = from_key_values("threads=16,active=true,path=/some/path").unwrap(); +//! assert_eq!(config, Config { path: "/some/path".into(), threads: 16, active: true }); +//! ``` +//! +//! As a convenience the name of the first field of a struct can be omitted: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! path: String, +//! threads: u8, +//! active: bool, +//! } +//! +//! let config: Config = from_key_values("/some/path,threads=16,active=true").unwrap(); +//! assert_eq!(config, Config { path: "/some/path".into(), threads: 16, active: true }); +//! ``` +//! +//! Fields that are behind an `Option` can be omitted, in which case they will be `None`. +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! path: Option, +//! threads: u8, +//! active: bool, +//! } +//! +//! let config: Config = from_key_values("path=/some/path,threads=16,active=true").unwrap(); +//! assert_eq!(config, Config { path: Some("/some/path".into()), threads: 16, active: true }); +//! +//! let config: Config = from_key_values("threads=16,active=true").unwrap(); +//! assert_eq!(config, Config { path: None, threads: 16, active: true }); +//! ``` +//! +//! Alternatively, the serde `default` attribute can be used on select fields or on the whole +//! struct to make unspecified fields be assigned their default value. In the following example only +//! the `path` parameter must be specified. +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! path: String, +//! #[serde(default)] +//! threads: u8, +//! #[serde(default)] +//! active: bool, +//! } +//! +//! let config: Config = from_key_values("path=/some/path").unwrap(); +//! assert_eq!(config, Config { path: "/some/path".into(), threads: 0, active: false }); +//! ``` +//! +//! A function providing a default value can also be specified, see the [serde documentation for +//! field attributes](https://serde.rs/field-attrs.html) for details. +//! +//! Booleans can be `true` or `false`, or take no value at all, in which case they will be `true`. +//! Combined with default values this allows to implement flags very easily: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, Default, PartialEq, Deserialize)] +//! #[serde(default)] +//! struct Config { +//! active: bool, +//! delayed: bool, +//! pooled: bool, +//! } +//! +//! let config: Config = from_key_values("active=true,delayed=false,pooled=true").unwrap(); +//! assert_eq!(config, Config { active: true, delayed: false, pooled: true }); +//! +//! let config: Config = from_key_values("active,pooled").unwrap(); +//! assert_eq!(config, Config { active: true, delayed: false, pooled: true }); +//! ``` +//! +//! Strings can be quoted, which is useful if they e.g. need to include a comma. Quoted strings can +//! also contain escaped characters, where any character after a `\` is repeated as-is: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! path: String, +//! } +//! +//! let config: Config = from_key_values(r#"path="/some/\"strange\"/pa,th""#).unwrap(); +//! assert_eq!(config, Config { path: r#"/some/"strange"/pa,th"#.into() }); +//! ``` +//! +//! Enums can be directly specified by name. It is recommended to use the `rename_all` serde +//! container attribute to make them parseable using snake or kebab case representation. Serde's +//! `rename` and `alias` field attributes can also be used to provide shorter values: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! #[serde(rename_all="kebab-case")] +//! enum Mode { +//! Slow, +//! Fast, +//! #[serde(rename="ludicrous")] +//! LudicrousSpeed, +//! } +//! +//! #[derive(Deserialize, PartialEq, Debug)] +//! struct Config { +//! mode: Mode, +//! } +//! +//! let config: Config = from_key_values("mode=slow").unwrap(); +//! assert_eq!(config, Config { mode: Mode::Slow }); +//! +//! let config: Config = from_key_values("mode=ludicrous").unwrap(); +//! assert_eq!(config, Config { mode: Mode::LudicrousSpeed }); +//! ``` +//! +//! Enums taking a single value should use the `flatten` field attribute in order to be inferred +//! from their variant key directly: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! #[serde(rename_all="kebab-case")] +//! enum Mode { +//! // Work with a local file. +//! File(String), +//! // Work with a remote URL. +//! Url(String), +//! } +//! +//! #[derive(Deserialize, PartialEq, Debug)] +//! struct Config { +//! #[serde(flatten)] +//! mode: Mode, +//! } +//! +//! let config: Config = from_key_values("file=/some/path").unwrap(); +//! assert_eq!(config, Config { mode: Mode::File("/some/path".into()) }); +//! +//! let config: Config = from_key_values("url=https://www.google.com").unwrap(); +//! assert_eq!(config, Config { mode: Mode::Url("https://www.google.com".into()) }); +//! ``` +//! +//! The `flatten` attribute can also be used to embed one struct within another one and parse both +//! from the same string: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct BaseConfig { +//! enabled: bool, +//! num_threads: u8, +//! } +//! +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! #[serde(flatten)] +//! base: BaseConfig, +//! path: String, +//! } +//! +//! let config: Config = from_key_values("path=/some/path,enabled,num_threads=16").unwrap(); +//! assert_eq!( +//! config, +//! Config { +//! path: "/some/path".into(), +//! base: BaseConfig { +//! num_threads: 16, +//! enabled: true, +//! } +//! } +//! ); +//! ``` +//! +//! If an enum's variants are made of structs, it should take the `untagged` container attribute so +//! it can be inferred directly from the fields of the embedded structs: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! #[serde(untagged)] +//! enum Mode { +//! // Work with a local file. +//! File { +//! path: String, +//! #[serde(default)] +//! read_only: bool, +//! }, +//! // Work with a remote URL. +//! Remote { +//! server: String, +//! port: u16, +//! } +//! } +//! +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! #[serde(flatten)] +//! mode: Mode, +//! } +//! +//! let config: Config = from_key_values("path=/some/path").unwrap(); +//! assert_eq!(config, Config { mode: Mode::File { path: "/some/path".into(), read_only: false } }); +//! +//! let config: Config = from_key_values("server=google.com,port=80").unwrap(); +//! assert_eq!(config, Config { mode: Mode::Remote { server: "google.com".into(), port: 80 } }); +//! ``` +//! +//! Using this crate, parsing errors and invalid or missing fields are precisely reported: +//! +//! ``` +//! # use serde_keyvalue::from_key_values; +//! # use serde::Deserialize; +//! #[derive(Debug, PartialEq, Deserialize)] +//! struct Config { +//! path: String, +//! threads: u8, +//! active: bool, +//! } +//! +//! let config = from_key_values::("path=/some/path,active=true").unwrap_err(); +//! assert_eq!(format!("{}", config), "missing field `threads`"); +//! ``` +//! +//! Most of the serde [container](https://serde.rs/container-attrs.html) and +//! [field](https://serde.rs/field-attrs.html) attributes can be applied to your configuration +//! struct. Most useful ones include +//! [`deny_unknown_fields`](https://serde.rs/container-attrs.html#deny_unknown_fields) to report an +//! error if an unknown field is met in the input, and +//! [`deserialize_with`](https://serde.rs/field-attrs.html#deserialize_with) to use a custom +//! deserialization function for a specific field. +#![deny(missing_docs)] + +mod key_values; + +pub use key_values::{from_key_values, ErrorKind, ParseError}; + +#[cfg(feature = "argh_derive")] +pub use argh; +#[cfg(feature = "argh_derive")] +pub use serde_keyvalue_derive::FromKeyValues;