use crate::rstd::{
boxed::Box, convert::TryInto, marker::PhantomData, ops::Range, vec, vec::Vec,
};
use hash_db::Hasher;
use crate::{
CError, ChildReference, nibble::LeftNibbleSlice, nibble_ops::NIBBLE_LENGTH, NibbleSlice, node::{NodeHandle, NodeHandlePlan, NodePlan, OwnedNode}, NodeCodec, Recorder,
Result as TrieResult, Trie, TrieError, TrieHash,
TrieLayout,
};
/// A decoded trie node on the current traversal path, plus the bookkeeping
/// required to re-encode it as a proof node once the traversal leaves it.
struct StackEntry<'a, C: NodeCodec> {
	/// Nibble prefix of the key at which this node sits.
	prefix: LeftNibbleSlice<'a>,
	/// The decoded node itself.
	node: OwnedNode<Vec<u8>>,
	/// Hash of the node's encoding; `None` for nodes reached as inline children.
	node_hash: Option<C::HashOut>,
	/// Slot reserved in the proof output for this node; `None` for inline nodes,
	/// which are embedded in their parent rather than emitted separately.
	output_index: Option<usize>,
	/// Set once a proved key's value is located in this node, so the value is
	/// stripped from the node's proof encoding.
	omit_value: bool,
	/// Index of the next child that has not been handled yet.
	child_index: usize,
	/// Child references for the proof encoding; slots before `child_index` have
	/// already been filled in during the traversal.
	children: Vec<Option<ChildReference<C::HashOut>>>,
	_marker: PhantomData<C>,
}
impl<'a, C: NodeCodec> StackEntry<'a, C> {
	/// Decodes `node_data` and wraps it in a fresh stack entry.
	///
	/// `node_hash` only labels a decoding error; the default hash is used when it
	/// is `None`. The `children` buffer is sized by node kind: no slots for empty
	/// and leaf nodes, one for extensions, and `NIBBLE_LENGTH` for branches.
	///
	/// # Errors
	/// Returns `TrieError::DecoderError` if `C` cannot decode `node_data`.
	fn new(
		prefix: LeftNibbleSlice<'a>,
		node_data: Vec<u8>,
		node_hash: Option<C::HashOut>,
		output_index: Option<usize>,
	) -> TrieResult<Self, C::HashOut, C::Error>
	{
		let node = OwnedNode::new::<C>(node_data)
			.map_err(|err| Box::new(
				TrieError::DecoderError(node_hash.unwrap_or_default(), err)
			))?;
		let children_len = match node.node_plan() {
			NodePlan::Empty | NodePlan::Leaf { .. } => 0,
			NodePlan::Extension { .. } => 1,
			NodePlan::Branch { .. } | NodePlan::NibbledBranch { .. } => NIBBLE_LENGTH,
		};
		Ok(StackEntry {
			prefix,
			node,
			node_hash,
			omit_value: false,
			child_index: 0,
			children: vec![None; children_len],
			output_index,
			_marker: PhantomData,
		})
	}

	/// Consumes the entry and produces its encoding for the proof.
	///
	/// - Empty nodes, and leaves/extensions that were not modified, are returned
	///   as their raw data.
	/// - A leaf whose value was proved is re-encoded with an empty value.
	/// - An extension whose child was handled is re-encoded with the child
	///   reference recorded by `set_child`.
	/// - Branches are re-encoded from `children`. Slots from `child_index` onward
	///   are first filled with the node's original child references, and the
	///   value is dropped if `omit_value` is set.
	///
	/// # Errors
	/// Returns `TrieError::InvalidHash` if an original child reference cannot be
	/// converted into a `ChildReference`.
	fn encode_node(mut self) -> TrieResult<Vec<u8>, C::HashOut, C::Error> {
		let node_data = self.node.data();
		Ok(match self.node.node_plan() {
			NodePlan::Empty => node_data.to_vec(),
			NodePlan::Leaf { .. } if !self.omit_value => node_data.to_vec(),
			NodePlan::Leaf { partial, value: _ } => {
				let partial = partial.build(node_data);
				C::leaf_node(partial.right(), &[])
			}
			NodePlan::Extension { .. } if self.child_index == 0 => node_data.to_vec(),
			NodePlan::Extension { partial: partial_plan, child: _ } => {
				let partial = partial_plan.build(node_data);
				let child = self.children[0]
					.expect(
						"for extension nodes, children[0] is guaranteed to be Some when \
						child_index > 0; \
						the branch guard guarantees that child_index > 0"
					);
				C::extension_node(
					partial.right_iter(),
					partial.len(),
					child
				)
			}
			NodePlan::Branch { value, children } => {
				Self::complete_branch_children(
					node_data,
					children,
					self.child_index,
					&mut self.children
				)?;
				C::branch_node(
					self.children.into_iter(),
					value_with_omission(node_data, value, self.omit_value)
				)
			},
			NodePlan::NibbledBranch { partial: partial_plan, value, children } => {
				let partial = partial_plan.build(node_data);
				Self::complete_branch_children(
					node_data,
					children,
					self.child_index,
					&mut self.children
				)?;
				C::branch_node_nibbled(
					partial.right_iter(),
					partial.len(),
					self.children.into_iter(),
					value_with_omission(node_data, value, self.omit_value)
				)
			},
		})
	}

	/// Fills `children[child_index..NIBBLE_LENGTH]` with the branch's original
	/// child references, decoded from `child_handles` against `node_data`.
	///
	/// # Errors
	/// Returns `TrieError::InvalidHash` if a child handle cannot be converted
	/// into a `ChildReference`.
	fn complete_branch_children(
		node_data: &[u8],
		child_handles: &[Option<NodeHandlePlan>; NIBBLE_LENGTH],
		child_index: usize,
		children: &mut [Option<ChildReference<C::HashOut>>],
	) -> TrieResult<(), C::HashOut, C::Error>
	{
		for i in child_index..NIBBLE_LENGTH {
			children[i] = child_handles[i]
				.as_ref()
				.map(|child_plan|
					child_plan
						.build(node_data)
						.try_into()
						.map_err(|hash| Box::new(
							TrieError::InvalidHash(C::HashOut::default(), hash)
						))
				)
				.transpose()?;
		}
		Ok(())
	}

	/// Records the replacement reference for the child at `child_index`, given
	/// the child's proof encoding, then advances `child_index`.
	///
	/// # Panics
	/// Panics if called on an empty or leaf node, on an extension whose child was
	/// already set, or on a branch whose `child_index` is out of range.
	fn set_child(&mut self, encoded_child: &[u8]) {
		let child_ref = match self.node.node_plan() {
			NodePlan::Empty | NodePlan::Leaf { .. } => panic!(
				"empty and leaf nodes have no children; \
				thus they are never descended into; \
				thus set_child will not be called on an entry with one of these types"
			),
			NodePlan::Extension { child, .. } => {
				assert_eq!(
					self.child_index, 0,
					"extension nodes only have one child; \
					set_child is called when the only child is popped from the stack; \
					child_index is 0 before child is pushed to the stack; qed"
				);
				Some(Self::replacement_child_ref(encoded_child, child))
			}
			NodePlan::Branch { children, .. } | NodePlan::NibbledBranch { children, .. } => {
				assert!(
					self.child_index < NIBBLE_LENGTH,
					"branch nodes have at most NIBBLE_LENGTH children; \
					set_child is called when a child is popped from the stack; \
					child_index is <NIBBLE_LENGTH before child is pushed to the stack; qed"
				);
				children[self.child_index]
					.as_ref()
					.map(|child| Self::replacement_child_ref(encoded_child, child))
			}
		};
		self.children[self.child_index] = child_ref;
		self.child_index += 1;
	}

	/// Builds the reference to a traversed child that goes into this node's
	/// proof encoding.
	///
	/// A hashed child becomes a zero-length inline reference. Its own encoding
	/// gets a separate slot in the proof (presumably so a verifier recomputes
	/// the hash from it). An inline child is re-embedded using its proof
	/// encoding `encoded_child`.
	///
	/// # Panics
	/// Panics if an inline child's encoding is longer than the hash output.
	fn replacement_child_ref(encoded_child: &[u8], child: &NodeHandlePlan)
		-> ChildReference<C::HashOut>
	{
		match child {
			NodeHandlePlan::Hash(_) => ChildReference::Inline(C::HashOut::default(), 0),
			NodeHandlePlan::Inline(_) => {
				let mut hash = C::HashOut::default();
				assert!(
					encoded_child.len() <= hash.as_ref().len(),
					"the encoding of the raw inline node is checked to be at most the hash length \
					before descending; \
					the encoding of the proof node is always smaller than the raw node as data is \
					only stripped"
				);
				hash.as_mut()[..encoded_child.len()].copy_from_slice(encoded_child);
				ChildReference::Inline(hash, encoded_child.len())
			}
		}
	}
}
/// Generates a compact proof for the values (or absence) of `keys` in `trie`.
///
/// The keys are sorted and deduplicated. Each key is looked up with
/// `Trie::get_with` and a `Recorder`, and the same lookup path is walked over a
/// stack of decoded nodes shared between consecutive keys. A proof slot is
/// reserved when a hashed node is first pushed, and filled with the node's
/// encoding when the node is popped (see `unwind_stack`). Inline nodes get no
/// slot of their own. Encoded nodes have the values of proved keys omitted and
/// their traversed children replaced (see `StackEntry::encode_node`).
///
/// # Errors
/// Propagates lookup and node-decoding errors. Returns `TrieError::InvalidHash`
/// when an inline child is longer than the hash length.
///
/// # Panics
/// Panics if this traversal disagrees with the nodes recorded by the lookup
/// (an internal invariant violation).
pub fn generate_proof<'a, T, L, I, K>(trie: &T, keys: I)
	-> TrieResult<Vec<Vec<u8>>, TrieHash<L>, CError<L>>
	where
		T: Trie<L>,
		L: TrieLayout,
		I: IntoIterator<Item=&'a K>,
		K: 'a + AsRef<[u8]>
{
	// Sorting makes keys that share a path adjacent, so the stack can be reused.
	let mut keys = keys.into_iter()
		.map(|key| key.as_ref())
		.collect::<Vec<_>>();
	keys.sort();
	keys.dedup();
	let mut stack = <Vec<StackEntry<L::Codec>>>::new();
	let mut proof_nodes = Vec::new();
	for key_bytes in keys {
		let key = LeftNibbleSlice::new(key_bytes);
		// Pop (and encode) every entry whose prefix is not a prefix of this key.
		unwind_stack(&mut stack, &mut proof_nodes, Some(&key))?;
		let mut recorder = Recorder::new();
		let expected_value = trie.get_with(key_bytes, &mut recorder)?;
		let mut recorded_nodes = recorder.drain().into_iter().peekable();
		// Skip recorded nodes that are already on the stack from a previous key.
		// Inline entries have no `node_hash`, so the skipping stops at the first one.
		{
			let mut stack_iter = stack.iter().peekable();
			while let (Some(next_record), Some(next_entry)) =
				(recorded_nodes.peek(), stack_iter.peek())
			{
				if next_entry.node_hash != Some(next_record.hash) {
					break;
				}
				recorded_nodes.next();
				stack_iter.next();
			}
		}
		// Walk down from the top of the stack (or the root) until the lookup ends.
		loop {
			let step = match stack.last_mut() {
				Some(entry) => match_key_to_node::<L::Codec>(
					entry.node.data(),
					entry.node.node_plan(),
					&mut entry.omit_value,
					&mut entry.child_index,
					&mut entry.children,
					&key,
					entry.prefix.len(),
				)?,
				None => Step::Descend {
					child_prefix_len: 0,
					child: NodeHandle::Hash(trie.root().as_ref()),
				},
			};
			match step {
				Step::Descend { child_prefix_len, child } => {
					let child_prefix = key.truncate(child_prefix_len);
					let child_entry = match child {
						NodeHandle::Hash(hash) => {
							let child_record = recorded_nodes.next()
								.expect(
									"this function's trie traversal logic mirrors that of Lookup; \
									thus the sequence of traversed nodes must be the same; \
									so the next child node must have been recorded and must have \
									the expected hash"
								);
							assert_eq!(child_record.hash.as_ref(), hash);
							// Reserve the output slot now so proof order follows first visit;
							// it is filled in when the entry is unwound.
							let output_index = proof_nodes.len();
							proof_nodes.push(Vec::new());
							StackEntry::new(
								child_prefix,
								child_record.data,
								Some(child_record.hash),
								Some(output_index),
							)?
						}
						NodeHandle::Inline(data) => {
							// An inline node must fit in a hash-sized slot of its parent.
							if data.len() > L::Hash::LENGTH {
								return Err(Box::new(
									TrieError::InvalidHash(<TrieHash<L>>::default(), data.to_vec())
								));
							}
							StackEntry::new(
								child_prefix,
								data.to_vec(),
								None,
								None,
							)?
						}
					};
					stack.push(child_entry);
				}
				Step::FoundValue(value) => {
					assert_eq!(
						value,
						expected_value.as_ref().map(|v| v.as_ref()),
						"expected_value is found using `trie_db::Lookup`; \
						value is found by traversing the same nodes recorded during the lookup \
						using the same logic; \
						thus the values found must be equal"
					);
					assert!(
						recorded_nodes.next().is_none(),
						"the recorded nodes are only recorded on the lookup path to the current \
						key; \
						recorded nodes is the minimal sequence of trie nodes on the lookup path; \
						the value was found by traversing recorded nodes, so there must be none \
						remaining"
					);
					break;
				}
			}
		}
	}
	// Encode whatever remains on the stack, down to the root.
	unwind_stack(&mut stack, &mut proof_nodes, None)?;
	Ok(proof_nodes)
}
/// Outcome of matching a key against a single node during proof generation.
enum Step<'a> {
	/// The lookup ends at this node. Holds `Some(value)` when the key's value is
	/// stored here, and `None` when the trie contains no value for the key.
	FoundValue(Option<&'a [u8]>),
	/// The lookup continues into `child`, whose key prefix consists of the
	/// first `child_prefix_len` nibbles of the key.
	Descend {
		child_prefix_len: usize,
		child: NodeHandle<'a>,
	},
}
/// Matches `key` against one node to decide whether the lookup stops here or
/// descends into a child.
///
/// `prefix_len` is the number of key nibbles consumed before reaching this node.
/// `omit_value`, `child_index` and `children` are the mutable traversal state of
/// the node's stack entry. Branch handling is delegated to
/// `match_key_to_branch_node`.
///
/// # Errors
/// Propagates errors from `match_key_to_branch_node`.
///
/// # Panics
/// Panics if an extension node is descended into after its child was handled.
fn match_key_to_node<'a, C: NodeCodec>(
	node_data: &'a [u8],
	node_plan: &NodePlan,
	omit_value: &mut bool,
	child_index: &mut usize,
	children: &mut [Option<ChildReference<C::HashOut>>],
	key: &LeftNibbleSlice,
	prefix_len: usize,
) -> TrieResult<Step<'a>, C::HashOut, C::Error>
{
	let step = match node_plan {
		NodePlan::Empty => Step::FoundValue(None),
		NodePlan::Leaf { partial: partial_plan, value: value_range } => {
			let partial = partial_plan.build(node_data);
			// The key must run through the leaf's partial and end exactly there.
			let key_ends_here = key.contains(&partial, prefix_len)
				&& key.len() == prefix_len + partial.len();
			if !key_ends_here {
				Step::FoundValue(None)
			} else {
				*omit_value = true;
				Step::FoundValue(Some(&node_data[value_range.clone()]))
			}
		}
		NodePlan::Extension { partial: partial_plan, child: child_plan } => {
			let partial = partial_plan.build(node_data);
			if !key.contains(&partial, prefix_len) {
				Step::FoundValue(None)
			} else {
				assert_eq!(*child_index, 0);
				Step::Descend {
					child_prefix_len: prefix_len + partial.len(),
					child: child_plan.build(node_data),
				}
			}
		}
		NodePlan::Branch { value, children: child_handles } =>
			match_key_to_branch_node::<C>(
				node_data,
				value,
				child_handles,
				omit_value,
				child_index,
				children,
				key,
				prefix_len,
				NibbleSlice::new(&[]),
			)?,
		NodePlan::NibbledBranch { partial: partial_plan, value, children: child_handles } =>
			match_key_to_branch_node::<C>(
				node_data,
				value,
				child_handles,
				omit_value,
				child_index,
				children,
				key,
				prefix_len,
				partial_plan.build(node_data),
			)?,
	};
	Ok(step)
}
/// Matches `key` against a branch node (with an optional nibble `partial`).
///
/// - If the key diverges from `partial`, the key is absent.
/// - If the key ends exactly at this branch, the branch's value (if any) is
///   returned and `omit_value` is set.
/// - Otherwise, children between `child_index` and the key's next nibble are
///   filled into `children` with their original references. The lookup then
///   descends into the child at that nibble, or reports absence if that slot is
///   empty.
///
/// # Errors
/// Returns `TrieError::InvalidHash` if a skipped child's handle cannot be
/// converted into a `ChildReference`.
///
/// # Panics
/// Panics if the key's next nibble precedes `child_index` (keys must be
/// processed in sorted order).
fn match_key_to_branch_node<'a, 'b, C: NodeCodec>(
	node_data: &'a [u8],
	value_range: &'b Option<Range<usize>>,
	child_handles: &'b [Option<NodeHandlePlan>; NIBBLE_LENGTH],
	omit_value: &mut bool,
	child_index: &mut usize,
	children: &mut [Option<ChildReference<C::HashOut>>],
	key: &'b LeftNibbleSlice<'b>,
	prefix_len: usize,
	partial: NibbleSlice<'b>,
) -> TrieResult<Step<'a>, C::HashOut, C::Error>
{
	if !key.contains(&partial, prefix_len) {
		return Ok(Step::FoundValue(None));
	}
	// Nibble offset just past this branch's partial key.
	let branch_end = prefix_len + partial.len();
	if key.len() == branch_end {
		*omit_value = true;
		let value = value_range.clone().map(|range| &node_data[range]);
		return Ok(Step::FoundValue(value));
	}
	let new_index = key.at(branch_end)
		.expect(
			"key contains partial key after entry key offset; \
			thus key len is greater than equal to entry key len plus partial key len; \
			also they are unequal due to the early return above; \
			qed"
		)
		as usize;
	assert!(*child_index <= new_index);
	// Children skipped over by this key keep their original references.
	while *child_index < new_index {
		children[*child_index] = child_handles[*child_index]
			.as_ref()
			.map(|child_plan|
				child_plan
					.build(node_data)
					.try_into()
					.map_err(|hash| Box::new(
						TrieError::InvalidHash(C::HashOut::default(), hash)
					))
			)
			.transpose()?;
		*child_index += 1;
	}
	if let Some(child_plan) = &child_handles[*child_index] {
		Ok(Step::Descend {
			child_prefix_len: branch_end + 1,
			child: child_plan.build(node_data),
		})
	} else {
		Ok(Step::FoundValue(None))
	}
}
/// Returns the bytes of a node's value taken from `node_data`.
///
/// Yields `None` either when the node carries no value (`value_range` is `None`)
/// or when `omit` asks for the value to be left out of the encoding.
///
/// # Panics
/// Panics if `omit` is false and `value_range` is out of bounds for `node_data`.
fn value_with_omission<'a>(
	node_data: &'a [u8],
	value_range: &Option<Range<usize>>,
	omit: bool
) -> Option<&'a [u8]>
{
	value_range
		.as_ref()
		.filter(|_| !omit)
		.map(|range| &node_data[range.clone()])
}
/// Pops stack entries that are not on the path to `maybe_key`, writing each
/// one's proof encoding.
///
/// Popping stops at the first entry whose prefix is a prefix of the key; that
/// entry stays on the stack. With `maybe_key == None` the whole stack is
/// drained. Each popped entry is encoded, reported to its parent (the new top
/// of the stack) via `set_child`, and stored in `proof_nodes` at its reserved
/// output index, if it has one.
///
/// # Errors
/// Propagates errors from `StackEntry::encode_node`.
fn unwind_stack<C: NodeCodec>(
	stack: &mut Vec<StackEntry<C>>,
	proof_nodes: &mut Vec<Vec<u8>>,
	maybe_key: Option<&LeftNibbleSlice>,
) -> TrieResult<(), C::HashOut, C::Error>
{
	loop {
		let entry = match stack.pop() {
			Some(entry) => entry,
			None => return Ok(()),
		};
		let on_key_path = maybe_key.map_or(false, |key| key.starts_with(&entry.prefix));
		if on_key_path {
			stack.push(entry);
			return Ok(());
		}
		let output_slot = entry.output_index;
		let encoded = entry.encode_node()?;
		if let Some(parent) = stack.last_mut() {
			parent.set_child(&encoded);
		}
		if let Some(slot) = output_slot {
			proof_nodes[slot] = encoded;
		}
	}
}