-- |
-- Module      : Crypto.PubKey.ElGamal
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
-- This module is a work in progress. Do not use it:
-- it might eat your dog, your data, or even both.
--
-- TODO: provide a mapping between integer and ciphertext
--       generate numbers correctly
--
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.PubKey.ElGamal
    ( Params
    , PublicNumber
    , PrivateNumber
    , EphemeralKey(..)
    , SharedKey
    , Signature
    -- * Generation
    , generatePrivate
    , generatePublic
    -- * Encryption and decryption with no scheme
    , encryptWith
    , encrypt
    , decrypt
    -- * Signature primitives
    , signWith
    , sign
    -- * Verification primitives
    , verify
    ) where

import Data.Maybe (fromJust)
import Crypto.Internal.Imports
import Crypto.Internal.ByteArray (ByteArrayAccess)
import Crypto.Number.ModArithmetic (expSafe, expFast, inverse)
import Crypto.Number.Generate (generateMax)
import Crypto.Number.Serialize (os2ip)
import Crypto.Number.Basic (gcde)
import Crypto.Random.Types
import Crypto.PubKey.DH (PrivateNumber(..), PublicNumber(..), Params(..), SharedKey(..))
import Crypto.Hash

-- | ElGamal signature: the pair @(r, s)@ returned by 'signWith' and 'sign'
-- and checked by 'verify'.
data Signature = Signature (Integer, Integer)

-- | ElGamal ephemeral key, also called the temporary key: the per-message
-- exponent @b@ consumed by 'encryptWith'. A fresh one must be used for
-- every message.
newtype EphemeralKey = EphemeralKey Integer
    deriving (NFData)

-- | Generate a private number @a@ with no particular structure.
--
-- The value is drawn with @'generateMax' q@, where @q@ is meant to be the
-- order of the group G (so presumably @0 <= a < q@ — see 'generateMax').
generatePrivate :: MonadRandom m => Integer -> m PrivateNumber
generatePrivate q = do
    a <- generateMax q
    return (PrivateNumber a)

-- | Generate an ephemeral key @b@ for a single encryption.
--
-- @q@ is the order of the group G. The key is drawn from @[1, q-1]@ rather
-- than @[0, q-1]@. A key of 0 must be excluded: in 'encryptWith' it would
-- give @c1 = g^0 = 1@ and @c2 = h^0 * m = m (mod p)@, i.e. the ciphertext
-- would carry the plaintext in the clear.
--
-- NOTE(review): assumes @q >= 2@; degenerate group orders are not handled.
generateEphemeral :: MonadRandom m => Integer -> m EphemeralKey
generateEphemeral q = toEphemeral <$> generateMax (q - 1)
  where
    -- generateMax (q - 1) yields [0, q-2]; shift it up to [1, q-1]
    toEphemeral n = EphemeralKey (n + 1)

-- | Compute the public number @h = g^a mod p@ from the private number @a@.
-- This is the value handed to the other party.
--
-- The exponent @a@ is secret, which is presumably why 'expSafe' is used here
-- rather than 'expFast' ('verify' uses 'expFast' on public values only).
generatePublic :: Params -> PrivateNumber -> PublicNumber
generatePublic (Params p g _) (PrivateNumber a) = PublicNumber h
  where
    h = expSafe g a p

-- | Encrypt the message @m@ for the holder of public number @h@, using an
-- explicitly supplied ephemeral key @b@.
--
-- Returns @(c1, c2) = (g^b mod p, (h^b * m) mod p)@. Never reuse an
-- ephemeral key: pass a fresh @b@ for every message.
encryptWith :: EphemeralKey -> Params -> PublicNumber -> Integer -> (Integer,Integer)
encryptWith (EphemeralKey b) (Params p g _) (PublicNumber h) m =
    let sharedSecret = expSafe h b p
        c1           = expSafe g b p
        c2           = (sharedSecret * m) `mod` p
     in (c1, c2)

-- | Encrypt a message with the given params and public number.
--
-- A fresh ephemeral key @b@ is generated via 'generateEphemeral' and the
-- actual work is delegated to 'encryptWith'.
encrypt :: MonadRandom m => Params -> PublicNumber -> Integer -> m (Integer,Integer)
encrypt params@(Params p _ _) public m =
    fmap encryptUnder (generateEphemeral groupOrder)
  where
    -- p is prime, hence the order of the group is p - 1
    groupOrder     = p - 1
    encryptUnder b = encryptWith b params public m

-- | Decrypt an ElGamal ciphertext @(c1, c2)@ with the private number @a@.
--
-- Computes @s = c1^a mod p@ and returns @(c2 * s^-1) mod p@.
--
-- @s@ has an inverse modulo @p@ only when it is non-zero modulo @p@. This
-- fails e.g. when @c1@ is a multiple of @p@, which a malformed or
-- attacker-supplied ciphertext can contain. In that case this function
-- calls 'error' with a descriptive message, instead of the anonymous
-- @fromJust@ failure it used to produce.
decrypt :: Params -> PrivateNumber -> (Integer, Integer) -> Integer
decrypt (Params p _ _) (PrivateNumber a) (c1, c2) = (c2 * sInv) `mod` p
  where
    s    = expSafe c1 a p
    sInv = case inverse s p of
        Just v  -> v
        Nothing -> error "Crypto.PubKey.ElGamal.decrypt: c1^a is not invertible modulo p (malformed ciphertext)"

-- | Sign a message with an explicitly supplied random number @k@.
--
-- @k@ must satisfy @0 < k < p-1@ and @gcd(k, p-1) = 1@. Any other @k@,
-- including a negative one, is rejected and 'Nothing' is returned.
--
-- Even with a valid @k@, signing fails ('Nothing') when the computed @s@ is
-- 0. The caller must then retry with a different @k@; 'sign' does this
-- automatically.
signWith :: (ByteArrayAccess msg, HashAlgorithm hash)
         => Integer         -- ^ random number k, with 0 < k < p-1 and gcd(k,p-1)=1
         -> Params          -- ^ DH params (p,g)
         -> PrivateNumber   -- ^ DH private key
         -> hash            -- ^ collision resistant hash algorithm
         -> msg             -- ^ message to sign
         -> Maybe Signature
signWith k (Params p g _) (PrivateNumber x) hashAlg msg
    | k <= 0 || k >= p - 1 = Nothing -- k outside the open interval (0, p-1)
    | d > 1                = Nothing -- gcd(k, p-1) is not 1
    | s == 0               = Nothing -- degenerate signature; retry with another k
    | otherwise            = Just $ Signature (r, s)
  where
    r            = expSafe g k p
    h            = os2ip $ hashWith hashAlg msg
    -- s = (h - x*r) * k^-1 mod (p-1)
    s            = ((h - x * r) * kInv) `mod` (p - 1)
    -- presumably kInv is the Bezout coefficient of k from 'gcde', i.e. the
    -- inverse of k modulo p-1 whenever d = 1 (checked above)
    (kInv, _, d) = gcde k (p - 1)

-- | Sign a message, drawing the random number @k@ with @'generateMax' (p - 1)@.
--
-- 'signWith' can reject a given @k@ or fail for it, so fresh values of
-- @k@ keep being drawn until a signature is produced.
sign :: (ByteArrayAccess msg, HashAlgorithm hash, MonadRandom m)
     => Params         -- ^ DH params (p,g)
     -> PrivateNumber  -- ^ DH private key
     -> hash           -- ^ collision resistant hash algorithm
     -> msg            -- ^ message to sign
     -> m Signature
sign params@(Params p _ _) priv hashAlg msg = attempt
  where
    -- one signing attempt; re-runs itself (drawing a new k) on failure
    attempt = do
        k <- generateMax (p - 1)
        maybe attempt return (signWith k params priv hashAlg msg)

-- | Verify an ElGamal signature @(r, s)@ over a message.
--
-- The signature is rejected unless @0 < r < p@ and @0 < s < p-1@.
-- Otherwise it is accepted iff @g^h == y^r * r^s (mod p)@, where @h@ is the
-- message hash read as an integer. Only public values are involved, so
-- 'expFast' is used throughout.
verify :: (ByteArrayAccess msg, HashAlgorithm hash)
       => Params
       -> PublicNumber
       -> hash
       -> msg
       -> Signature
       -> Bool
verify (Params p g _) (PublicNumber y) hashAlg msg (Signature (r, s))
    | within p r && within (p - 1) s = lhs == rhs
    | otherwise                      = False
  where
    -- strict bounds: 0 < v < bound
    within bound v = v > 0 && v < bound
    h   = os2ip $ hashWith hashAlg msg
    lhs = expFast g h p
    rhs = (expFast y r p * expFast r s p) `mod` p