module Data.Conduit.Zstd
    ( compress
    , decompress
    ) where

import qualified Data.Conduit as C
import qualified Data.ByteString as B
import qualified Codec.Compression.Zstd.Streaming as Z
import           Control.Monad.IO.Class (MonadIO, liftIO)
import           Control.Exception.Base (throwIO)
import           System.IO.Error (userError)
import           Data.Maybe (fromMaybe)


-- | Compression conduit.
--
-- Initialises a zstd streaming compressor with the given level and then
-- hands its initial state to 'go', which feeds it the upstream chunks and
-- yields whatever output the compressor produces.
--
-- A failure reported by the compressor surfaces as an 'IOError' thrown
-- from 'go'.
compress :: MonadIO m =>
        Int -- ^ compression level
        -> C.ConduitT B.ByteString B.ByteString m ()
compress level = liftIO (Z.compress level) >>= go

-- | Decompression conduit.
--
-- Initialises a zstd streaming decompressor and hands its initial state to
-- 'go', which feeds it the upstream chunks and yields whatever output the
-- decompressor produces.
--
-- A failure reported by the decompressor surfaces as an 'IOError' thrown
-- from 'go'.
decompress :: MonadIO m => C.ConduitT B.ByteString B.ByteString m ()
decompress = liftIO Z.decompress >>= go

-- | Drive a zstd streaming state ('Z.Result') as a conduit.
--
-- Each constructor of the state is handled as follows:
--
-- * 'Z.Produce': yield the produced chunk downstream, run the continuation
--   and keep going with the resulting state.
-- * 'Z.Consume': await the next upstream chunk and pass it to the
--   continuation.
-- * 'Z.Error': throw an 'IOError' (built with 'userError') that carries
--   both strings from the error.
-- * 'Z.Done': yield the final chunk and stop.
go :: MonadIO m => Z.Result -> C.ConduitT B.ByteString B.ByteString m ()
go (Z.Produce r next) = do
    C.yield r
    liftIO next >>= go
go input@(Z.Consume f) = do
    next <- C.await
    case next of
      -- Empty upstream chunks are skipped rather than forwarded.
      -- NOTE(review): presumably because an empty ByteString is the
      -- end-of-input signal for the streaming API -- confirm against the
      -- zstd Streaming documentation.
      Just chunk | B.null chunk -> go input
      -- Either a non-empty chunk, or upstream is exhausted ('Nothing'), in
      -- which case an empty ByteString is passed to the continuation.
      _ -> liftIO (f $ fromMaybe B.empty next) >>= go
go (Z.Error m e) =
    liftIO (throwIO $ userError ("ZStandard error :" ++ m ++ ": " ++ e))
go (Z.Done r) = C.yield r