Crypto++
pssr.cpp
1 // pssr.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "pssr.h"
5 #include <functional>
6 
7 NAMESPACE_BEGIN(CryptoPP)
8 
9 // more in dll.cpp
10 template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31;
11 template<> const byte EMSA2HashId<RIPEMD128>::id = 0x32;
12 template<> const byte EMSA2HashId<Whirlpool>::id = 0x37;
13 
14 #ifndef CRYPTOPP_IMPORTS
15 
16 size_t PSSR_MEM_Base::MinRepresentativeBitLength(size_t hashIdentifierLength, size_t digestLength) const
17 {
18  size_t saltLen = SaltLen(digestLength);
19  size_t minPadLen = MinPadLen(digestLength);
20  return 9 + 8*(minPadLen + saltLen + digestLength + hashIdentifierLength);
21 }
22 
23 size_t PSSR_MEM_Base::MaxRecoverableLength(size_t representativeBitLength, size_t hashIdentifierLength, size_t digestLength) const
24 {
25  if (AllowRecovery())
26  return SaturatingSubtract(representativeBitLength, MinRepresentativeBitLength(hashIdentifierLength, digestLength)) / 8;
27  return 0;
28 }
29 
30 bool PSSR_MEM_Base::IsProbabilistic() const
31 {
32  return SaltLen(1) > 0;
33 }
34 
35 bool PSSR_MEM_Base::AllowNonrecoverablePart() const
36 {
37  return true;
38 }
39 
40 bool PSSR_MEM_Base::RecoverablePartFirst() const
41 {
42  return false;
43 }
44 
45 void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng,
46  const byte *recoverableMessage, size_t recoverableMessageLength,
47  HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
48  byte *representative, size_t representativeBitLength) const
49 {
50  assert(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize()));
51 
52  const size_t u = hashIdentifier.second + 1;
53  const size_t representativeByteLength = BitsToBytes(representativeBitLength);
54  const size_t digestSize = hash.DigestSize();
55  const size_t saltSize = SaltLen(digestSize);
56  byte *const h = representative + representativeByteLength - u - digestSize;
57 
58  SecByteBlock digest(digestSize), salt(saltSize);
59  hash.Final(digest);
60  rng.GenerateBlock(salt, saltSize);
61 
62  // compute H = hash of M'
63  byte c[8];
64  PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
65  PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
66  hash.Update(c, 8);
67  hash.Update(recoverableMessage, recoverableMessageLength);
68  hash.Update(digest, digestSize);
69  hash.Update(salt, saltSize);
70  hash.Final(h);
71 
72  // compute representative
73  GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize, false);
74  byte *xorStart = representative + representativeByteLength - u - digestSize - salt.size() - recoverableMessageLength - 1;
75  xorStart[0] ^= 1;
76  xorbuf(xorStart + 1, recoverableMessage, recoverableMessageLength);
77  xorbuf(xorStart + 1 + recoverableMessageLength, salt, salt.size());
78  memcpy(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second);
79  representative[representativeByteLength - 1] = hashIdentifier.second ? 0xcc : 0xbc;
80  if (representativeBitLength % 8 != 0)
81  representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);
82 }
83 
84 DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative(
85  HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
86  byte *representative, size_t representativeBitLength,
87  byte *recoverableMessage) const
88 {
89  assert(representativeBitLength >= MinRepresentativeBitLength(hashIdentifier.second, hash.DigestSize()));
90 
91  const size_t u = hashIdentifier.second + 1;
92  const size_t representativeByteLength = BitsToBytes(representativeBitLength);
93  const size_t digestSize = hash.DigestSize();
94  const size_t saltSize = SaltLen(digestSize);
95  const byte *const h = representative + representativeByteLength - u - digestSize;
96 
97  SecByteBlock digest(digestSize);
98  hash.Final(digest);
99 
100  DecodingResult result(0);
101  bool &valid = result.isValidCoding;
102  size_t &recoverableMessageLength = result.messageLength;
103 
104  valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid;
105  valid = VerifyBufsEqual(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) && valid;
106 
107  GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize);
108  if (representativeBitLength % 8 != 0)
109  representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);
110 
111  // extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt
112  byte *salt = representative + representativeByteLength - u - digestSize - saltSize;
113  byte *M = std::find_if(representative, salt-1, std::bind2nd(std::not_equal_to<byte>(), 0));
114  recoverableMessageLength = salt-M-1;
115  if (*M == 0x01
116  && (size_t)(M - representative - (representativeBitLength % 8 != 0)) >= MinPadLen(digestSize)
117  && recoverableMessageLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize))
118  {
119  memcpy(recoverableMessage, M+1, recoverableMessageLength);
120  }
121  else
122  {
123  recoverableMessageLength = 0;
124  valid = false;
125  }
126 
127  // verify H = hash of M'
128  byte c[8];
129  PutWord(false, BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
130  PutWord(false, BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
131  hash.Update(c, 8);
132  hash.Update(recoverableMessage, recoverableMessageLength);
133  hash.Update(digest, digestSize);
134  hash.Update(salt, saltSize);
135  valid = hash.Verify(h) && valid;
136 
137  if (!AllowRecovery() && valid && recoverableMessageLength != 0)
138  {throw NotImplemented("PSSR_MEM: message recovery disabled");}
139 
140  return result;
141 }
142 
143 #endif
144 
145 NAMESPACE_END