From 397ee49bbff395720c57b8600c9051f69c8ef264 Mon Sep 17 00:00:00 2001 From: Dermot Haughey Date: Wed, 16 Nov 2022 16:06:30 -0600 Subject: [PATCH] feat: adds hmget command Adds `hmget` command to resp protocol implementation. --- src/protocol/resp/src/request/hmget.rs | 137 +++++++++++++++++++++++++ src/protocol/resp/src/request/mod.rs | 15 +++ 2 files changed, 152 insertions(+) create mode 100644 src/protocol/resp/src/request/hmget.rs diff --git a/src/protocol/resp/src/request/hmget.rs b/src/protocol/resp/src/request/hmget.rs new file mode 100644 index 000000000..ab7eb483a --- /dev/null +++ b/src/protocol/resp/src/request/hmget.rs @@ -0,0 +1,137 @@ +use super::*; +use std::io::{Error, ErrorKind}; +use std::sync::Arc; + +type ArcByteSlice = Arc>; +#[derive(Debug, PartialEq, Eq)] +pub struct HmGetRequest { + key: ArcByteSlice, + fields: Arc>, +} + +impl HmGetRequest { + pub fn key(&self) -> &[u8] { + &self.key + } + + pub fn fields(&self) -> Box<[&[u8]]> { + self.fields + .iter() + .map(|f| &***f) + .collect::>() + .into_boxed_slice() + } +} + +impl TryFrom for HmGetRequest { + type Error = Error; + + fn try_from(other: Message) -> Result { + if let Message::Array(array) = other { + if array.inner.is_none() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + let mut array = array.inner.unwrap(); + + if array.len() <= 2 { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + let key = take_bulk_string(&mut array)?; + if key.is_empty() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + let mut fields = Vec::with_capacity(array.len()); + while array.len() >= 2 { + let field = take_bulk_string(&mut array)?; + if field.is_empty() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + fields.push(field); + } + + let f = Arc::new(Box::<[ArcByteSlice]>::from(fields)); + Ok(Self { key, fields: f }) + } else { + Err(Error::new(ErrorKind::Other, "malformed command")) + } + } +} + +impl From<&HmGetRequest> for Message { + fn from(other: &HmGetRequest) -> Message { + let mut v = vec![ + Message::bulk_string(b"HMGET"), + Message::BulkString(BulkString::from(other.key.clone())), + ]; + for kv in (*other.fields).iter() { + v.push(Message::BulkString(BulkString::from(kv.clone()))); + } + + Message::Array(Array { inner: Some(v) }) + } +} + +impl Compose for HmGetRequest { + fn compose(&self, buf: &mut dyn BufMut) -> usize { + let message = Message::from(self); + message.compose(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parser() { + let parser = RequestParser::new(); + + //1 field + if let Request::HmGet(request) = parser.parse(b"hmget key field1\r\n").unwrap().into_inner() + { + assert_eq!(request.key(), b"key"); + assert_eq!(request.fields().len(), 1); + assert_eq!(request.fields()[0], b"field1"); + } else { + panic!("invalid parse result"); + } + + //2 fields + if let Request::HmGet(request) = parser + .parse(b"hmget key field1 field2\r\n") + .unwrap() + .into_inner() + { + assert_eq!(request.key(), b"key"); + assert_eq!(request.fields().len(), 2); + assert_eq!(request.fields()[0], b"field1"); + assert_eq!(request.fields()[1], b"field2"); + } else { + panic!("invalid parse result"); + } + + //3 fields + if let Request::HmGet(request) = parser + .parse(b"hmget key field1 field2 42\r\n") + .unwrap() + .into_inner() + { + assert_eq!(request.key(), b"key"); + assert_eq!(request.fields().len(), 3); + assert_eq!(request.fields()[0], b"field1"); + assert_eq!(request.fields()[1], b"field2"); + assert_eq!(request.fields()[2], b"42"); + } else { + panic!("invalid parse result"); + } + + //insufficient whitespace delimited strings + parser + .parse(b"hmget key\r\n") + .expect_err("malformed command"); + } +} diff --git a/src/protocol/resp/src/request/mod.rs b/src/protocol/resp/src/request/mod.rs index 534ccebc8..15b3928a7 100644 --- a/src/protocol/resp/src/request/mod.rs +++ b/src/protocol/resp/src/request/mod.rs @@ -12,10 +12,12 @@ use std::sync::Arc; mod badd; mod get; +mod hmget; mod set; pub use badd::BAddRequest; pub use get::GetRequest; +pub use hmget::HmGetRequest; pub use set::SetRequest; #[derive(Default)] @@ -95,6 +97,9 @@ impl Parse for RequestParser { Some(b"get") | Some(b"GET") => { GetRequest::try_from(message).map(Request::from) } + Some(b"hmget") | Some(b"HMGET") => { + HmGetRequest::try_from(message).map(Request::from) + } Some(b"set") | Some(b"SET") => { SetRequest::try_from(message).map(Request::from) } @@ -120,6 +125,7 @@ impl Compose for Request { match self { Self::BAdd(r) => r.compose(buf), Self::Get(r) => r.compose(buf), + Self::HmGet(r) => r.compose(buf), Self::Set(r) => r.compose(buf), } } @@ -129,6 +135,7 @@ impl Compose for Request { pub enum Request { BAdd(BAddRequest), Get(GetRequest), + HmGet(HmGetRequest), Set(SetRequest), } @@ -144,6 +151,12 @@ impl From for Request { } } +impl From for Request { + fn from(other: HmGetRequest) -> Self { + Self::HmGet(other) + } +} + impl From for Request { fn from(other: SetRequest) -> Self { Self::Set(other) @@ -154,6 +167,7 @@ impl From for Request { pub enum Command { BAdd, Get, + HmGet, Set, } @@ -164,6 +178,7 @@ impl TryFrom<&[u8]> for Command { match other { b"badd" | b"BADD" => Ok(Command::BAdd), b"get" | b"GET" => Ok(Command::Get), + b"hmget" | b"HMGET" => Ok(Command::HmGet), b"set" | b"SET" => Ok(Command::Set), _ => Err(()), }