Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

oaep.cpp

00001 // oaep.cpp - written and placed in the public domain by Wei Dai
00002 
00003 #include "pch.h"
00004 #include "oaep.h"
00005 
00006 #include <functional>
00007 
00008 NAMESPACE_BEGIN(CryptoPP)
00009 
00010 // ********************************************************
00011 
00012 ANONYMOUS_NAMESPACE_BEGIN
00013         template <class H, byte *P, unsigned int PLen>
00014         struct PHashComputation
00015         {
00016                 PHashComputation()      {H().CalculateDigest(pHash, P, PLen);}
00017                 byte pHash[H::DIGESTSIZE];
00018         };
00019 
00020         template <class H, byte *P, unsigned int PLen>
00021         const byte *PHash()
00022         {
00023                 static PHashComputation<H,P,PLen> pHash;
00024                 return pHash.pHash;
00025         }
00026 NAMESPACE_END
00027 
00028 template <class H, class MGF, byte *P, unsigned int PLen>
00029 unsigned int OAEP<H,MGF,P,PLen>::MaxUnpaddedLength(unsigned int paddedLength) const
00030 {
00031         return paddedLength/8 > 1+2*H::DIGESTSIZE ? paddedLength/8-1-2*H::DIGESTSIZE : 0;
00032 }
00033 
00034 template <class H, class MGF, byte *P, unsigned int PLen>
00035 void OAEP<H,MGF,P,PLen>::Pad(RandomNumberGenerator &rng, const byte *input, unsigned int inputLength, byte *oaepBlock, unsigned int oaepBlockLen) const
00036 {
00037         assert (inputLength <= MaxUnpaddedLength(oaepBlockLen));
00038 
00039         // convert from bit length to byte length
00040         if (oaepBlockLen % 8 != 0)
00041         {
00042                 oaepBlock[0] = 0;
00043                 oaepBlock++;
00044         }
00045         oaepBlockLen /= 8;
00046 
00047         const unsigned int hLen = H::DIGESTSIZE;
00048         const unsigned int seedLen = hLen, dbLen = oaepBlockLen-seedLen;
00049         byte *const maskedSeed = oaepBlock;
00050         byte *const maskedDB = oaepBlock+seedLen;
00051 
00052         // DB = pHash || 00 ... || 01 || M
00053         memcpy(maskedDB, PHash<H,P,PLen>(), hLen);
00054         memset(maskedDB+hLen, 0, dbLen-hLen-inputLength-1);
00055         maskedDB[dbLen-inputLength-1] = 0x01;
00056         memcpy(maskedDB+dbLen-inputLength, input, inputLength);
00057 
00058         rng.GenerateBlock(maskedSeed, seedLen);
00059         H h;
00060         MGF mgf;
00061         mgf.GenerateAndMask(h, maskedDB, dbLen, maskedSeed, seedLen);
00062         mgf.GenerateAndMask(h, maskedSeed, seedLen, maskedDB, dbLen);
00063 }
00064 
00065 template <class H, class MGF, byte *P, unsigned int PLen>
00066 DecodingResult OAEP<H,MGF,P,PLen>::Unpad(const byte *oaepBlock, unsigned int oaepBlockLen, byte *output) const
00067 {
00068         bool invalid = false;
00069 
00070         // convert from bit length to byte length
00071         if (oaepBlockLen % 8 != 0)
00072         {
00073                 invalid = (oaepBlock[0] != 0) || invalid;
00074                 oaepBlock++;
00075         }
00076         oaepBlockLen /= 8;
00077 
00078         const unsigned int hLen = H::DIGESTSIZE;
00079         const unsigned int seedLen = hLen, dbLen = oaepBlockLen-seedLen;
00080 
00081         invalid = (oaepBlockLen < 2*hLen+1) || invalid;
00082 
00083         SecByteBlock t(oaepBlock, oaepBlockLen);
00084         byte *const maskedSeed = t;
00085         byte *const maskedDB = t+seedLen;
00086 
00087         H h;
00088         MGF mgf;
00089         mgf.GenerateAndMask(h, maskedSeed, seedLen, maskedDB, dbLen);
00090         mgf.GenerateAndMask(h, maskedDB, dbLen, maskedSeed, seedLen);
00091 
00092         // DB = pHash' || 00 ... || 01 || M
00093 
00094         byte *M = std::find(maskedDB+hLen, maskedDB+dbLen, 0x01);
00095         invalid = (M == maskedDB+dbLen) || invalid;
00096         invalid = (std::find_if(maskedDB+hLen, M, std::bind2nd(std::not_equal_to<byte>(), 0)) != M) || invalid;
00097         invalid = (memcmp(maskedDB, PHash<H,P,PLen>(), hLen) != 0) || invalid;
00098 
00099         if (invalid)
00100                 return DecodingResult();
00101 
00102         M++;
00103         memcpy(output, M, maskedDB+dbLen-M);
00104         return DecodingResult(maskedDB+dbLen-M);
00105 }
00106 
00107 NAMESPACE_END

Generated on Sun Mar 14 20:44:27 2004 for Crypto++ by doxygen 1.3.6-20040222