{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -fno-specialize #-}
{-# OPTIONS_HADDOCK prune #-}

-- | A purely functional implementation of MerkleTrees that is suitable for
-- usage on-chain. Note however that the construction of 'MerkleTree' and
-- membership proofs are still expected to happen *off-chain* while only the
-- proof verification should be done on-chain.
--
-- Note that this module is meant to be used as a qualified import, for example:
--
-- @
-- import qualified Plutus.MerkleTree as MT
-- @
module Plutus.MerkleTree where

import PlutusPrelude hiding (toList)

import PlutusTx qualified
import PlutusTx.Builtins (divideInteger)
import PlutusTx.List qualified as List
import PlutusTx.Prelude hiding (toList)

import Data.ByteString.Base16 qualified as Haskell.Base16
import Data.Text qualified as Haskell.Text
import Data.Text.Encoding qualified as Haskell.Text.Encoding
import Prelude qualified as Haskell

-- * MerkleTree

-- | A binary Merkle tree over serialized elements, suitable for on-chain
-- manipulation.
--
-- Every leaf and inner node caches its own hash, so the root hash can be read
-- off without recomputation. Trees should not be assembled by hand: use
-- 'fromList', which keeps the cached hashes consistent with the contents.
data MerkleTree
  = -- | A tree holding no elements.
    MerkleEmpty
  | -- | An inner node: its cached hash, then the left and right subtrees.
    MerkleNode Hash MerkleTree MerkleTree
  | -- | A leaf: the cached hash of the element, then the element itself.
    MerkleLeaf Hash BuiltinByteString
  deriving (Haskell.Eq, Haskell.Show)

-- | On-chain equality. Two empty trees are equal. Two leaves, or two inner
-- nodes, are equal when their cached hashes are equal; children and elements
-- are not inspected. Trees whose roots are different constructors are never
-- equal. (The derived 'Haskell.Eq' instance, by contrast, compares the whole
-- structure.)
instance Eq MerkleTree where
  lhs == rhs = case (lhs, rhs) of
    (MerkleEmpty, MerkleEmpty) -> True
    (MerkleLeaf hl _, MerkleLeaf hr _) -> hl == hr
    (MerkleNode hl _ _, MerkleNode hr _ _) -> hl == hr
    _ -> False

-- | Build a balanced 'MerkleTree' whose leaves are the given serialized
-- elements, in order.
--
-- The list is split recursively: the left subtree receives the first
-- @n `divideInteger` 2@ elements and the right subtree the rest. Each inner
-- node's hash is 'combineHash' of its children's root hashes.
--
-- Note that, while this operation is doable on-chain, it is expensive and
-- preferably done off-chain.
fromList :: [BuiltinByteString] -> MerkleTree
fromList xs = build (length xs) xs
 where
  -- The first argument is the length of the list, threaded through so it is
  -- never recomputed.
  build :: Integer -> [BuiltinByteString] -> MerkleTree
  build _ [] = MerkleEmpty
  build _ [x] = MerkleLeaf (hash x) x
  build n ys =
    MerkleNode (combineHash (rootHash left) (rootHash right)) left right
   where
    half = n `divideInteger` 2
    left = build half (List.take half ys)
    right = build (n - half) (drop half ys)
{-# INLINEABLE fromList #-}

-- | Deconstruct a 'MerkleTree' back to its list of elements, left to right.
--
-- >>> toList (fromList xs) == xs
-- True
--
-- Elements are consed onto an accumulator rather than concatenated with
-- '<>'. This keeps the traversal linear in the number of leaves; repeated
-- appends would re-walk the left halves at every level.
toList :: MerkleTree -> [BuiltinByteString]
toList tree = go tree []
 where
  -- Prepend the elements of @node@ (in order) to @acc@.
  go :: MerkleTree -> [BuiltinByteString] -> [BuiltinByteString]
  go node acc = case node of
    MerkleEmpty -> acc
    MerkleLeaf _ e -> e : acc
    -- Right subtree first, so that the left elements end up in front.
    MerkleNode _ l r -> go l (go r acc)
{-# INLINEABLE toList #-}

-- | Obtain the root hash of a 'MerkleTree'. Leaves and inner nodes simply
-- return their cached hash.
--
-- The empty tree's root is the empty byte string. Every output of 'hash' is a
-- SHA-256 digest (32 bytes), so this value can never equal a leaf or node
-- hash. As a result, the empty tree cannot share a root with any non-empty
-- tree, and no element can be proven a member of it. Using @hash ""@ here
-- would make @rootHash MerkleEmpty == rootHash (fromList [""])@, and
-- @member "" (rootHash MerkleEmpty) []@ would succeed.
--
-- In particular, barring hash collisions, we have:
--
-- >>> (mt == mt') == (rootHash mt == rootHash mt')
-- True
rootHash :: MerkleTree -> Hash
rootHash = \case
  MerkleEmpty -> Hash ""
  MerkleLeaf h _ -> h
  MerkleNode h _ _ -> h
{-# INLINEABLE rootHash #-}

-- | Return 'True' exactly when the tree is 'MerkleEmpty'.
--
-- For trees built with 'fromList':
--
-- >>> null mt == (size mt == 0)
-- True
null :: MerkleTree -> Bool
null MerkleEmpty = True
null _ = False
{-# INLINEABLE null #-}

-- | Total number of leaves (i.e. elements) in the tree. This walks the whole
-- tree.
size :: MerkleTree -> Integer
size MerkleEmpty = 0
size MerkleLeaf{} = 1
size (MerkleNode _ left right) = size left + size right
{-# INLINEABLE size #-}

-- * Proof

-- | A membership 'Proof': the sibling hashes met on the path from a leaf up
-- to the root, listed leaf-first.
--
-- A 'Left' sibling lies to the left of the path, and a 'Right' sibling to the
-- right. The type is meant to be opaque: build proofs with 'mkProof' and
-- check them with 'member'.
type Proof = [Either Hash Hash]

-- | Construct a membership 'Proof' for an element of a 'MerkleTree'.
--
-- Returns 'Nothing' if the element isn't a member of the tree to begin with.
-- Leaves are matched by hash. When several leaves match, the proof for the
-- leftmost one is returned. In the worst case the whole tree is visited, so
-- this is meant to run off-chain.
mkProof :: BuiltinByteString -> MerkleTree -> Maybe Proof
mkProof e = search []
 where
  target :: Hash
  target = hash e

  -- Descend the tree. At each inner node, push the hash of the sibling
  -- subtree not taken, so the finished proof is ordered leaf-first.
  search :: Proof -> MerkleTree -> Maybe Proof
  search _ MerkleEmpty = Nothing
  search path (MerkleLeaf h _) =
    if h == target then Just path else Nothing
  search path (MerkleNode _ l r) =
    search (Right (rootHash r) : path) l
      <|> search (Left (rootHash l) : path) r
{-# INLINEABLE mkProof #-}

-- | Check whether an element is part of a 'MerkleTree' using only the tree's
-- root hash and a 'Proof' (as produced by 'mkProof').
--
-- Verification costs one 'combineHash' per proof step. For trees built with
-- 'fromList', the proof length is logarithmic in the tree's size, which is
-- why this data structure is interesting on-chain in the first place.
--
-- NOTE(review): leaves are hashed with 'hash' and inner nodes with
-- 'combineHash', which hashes the plain concatenation of the two child
-- digests. There is no domain separation between the two, so an "element"
-- equal to two concatenated child digests hashes to that inner node's hash.
-- Confirm callers do not accept such elements.
member :: BuiltinByteString -> Hash -> Proof -> Bool
member e root = climb (hash e)
 where
  -- Fold the proof from the leaf upwards, then compare with the root.
  climb :: Hash -> Proof -> Bool
  climb acc [] = acc == root
  climb acc (step : rest) = case step of
    Left sibling -> climb (combineHash sibling acc) rest
    Right sibling -> climb (combineHash acc sibling) rest
{-# INLINEABLE member #-}

-- * Hash

-- | A hash digest. Wrapping the raw bytes in a newtype keeps digests from
-- being confused with arbitrary 'BuiltinByteString' payloads.
newtype Hash = Hash BuiltinByteString
  deriving (Haskell.Eq)

-- | On-chain equality: two digests are equal when their bytes are equal.
instance Eq Hash where
  (==) (Hash lhs) (Hash rhs) = lhs == rhs

-- | Off-chain debug rendering: only the first 4 bytes of the digest, as
-- base16 (hex) text.
instance Haskell.Show Hash where
  show (Hash digest) =
    let prefix = fromBuiltin (takeByteString 4 digest)
        hex = Haskell.Base16.encode prefix
     in Haskell.Text.unpack (Haskell.Text.Encoding.decodeUtf8 hex)

-- | Hash an arbitrary 'BuiltinByteString' message with SHA-256.
hash :: BuiltinByteString -> Hash
hash message = Hash (sha2_256 message)
{-# INLINEABLE hash #-}

-- | Combine two hash digests into a new one by hashing their concatenation,
-- left digest first. The result comes from 'hash', so it is a digest of the
-- same length as any other.
combineHash :: Hash -> Hash -> Hash
combineHash (Hash left) (Hash right) = hash (left `appendByteString` right)
{-# INLINEABLE combineHash #-}

-- Template Haskell: generate PlutusTx's data-encoding instances for 'Hash'.

$(PlutusTx.unstableMakeIsData ''Hash)