import polarssl 1.2.6's rsa
authorNIIBE Yutaka <gniibe@fsij.org>
Tue, 19 Mar 2013 04:19:55 +0000 (13:19 +0900)
committerNIIBE Yutaka <gniibe@fsij.org>
Tue, 19 Mar 2013 04:19:55 +0000 (13:19 +0900)
polarssl/include/polarssl/rsa.h
polarssl/library/rsa.c

index c0569b4..f9a0220 100644 (file)
@@ -1,6 +1,8 @@
 /**
  * \file rsa.h
  *
+ * \brief The RSA public-key cryptosystem
+ *
  *  Copyright (C) 2006-2010, Brainspark B.V.
  *
  *  This file is part of PolarSSL (http://www.polarssl.org)
 #ifndef POLARSSL_RSA_H
 #define POLARSSL_RSA_H
 
-#include "polarssl/bignum.h"
+#include "bignum.h"
 
 /*
  * RSA Error codes
  */
-#define POLARSSL_ERR_RSA_BAD_INPUT_DATA                    -0x0400
-#define POLARSSL_ERR_RSA_INVALID_PADDING                   -0x0410
-#define POLARSSL_ERR_RSA_KEY_GEN_FAILED                    -0x0420
-#define POLARSSL_ERR_RSA_KEY_CHECK_FAILED                  -0x0430
-#define POLARSSL_ERR_RSA_PUBLIC_FAILED                     -0x0440
-#define POLARSSL_ERR_RSA_PRIVATE_FAILED                    -0x0450
-#define POLARSSL_ERR_RSA_VERIFY_FAILED                     -0x0460
-#define POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE                  -0x0470
-#define POLARSSL_ERR_RSA_RNG_FAILED                        -0x0480
+#define POLARSSL_ERR_RSA_BAD_INPUT_DATA                    -0x4080  /**< Bad input parameters to function. */
+#define POLARSSL_ERR_RSA_INVALID_PADDING                   -0x4100  /**< Input data contains invalid padding and is rejected. */
+#define POLARSSL_ERR_RSA_KEY_GEN_FAILED                    -0x4180  /**< Something failed during generation of a key. */
+#define POLARSSL_ERR_RSA_KEY_CHECK_FAILED                  -0x4200  /**< Key failed to pass the libraries validity check. */
+#define POLARSSL_ERR_RSA_PUBLIC_FAILED                     -0x4280  /**< The public key operation failed. */
+#define POLARSSL_ERR_RSA_PRIVATE_FAILED                    -0x4300  /**< The private key operation failed. */
+#define POLARSSL_ERR_RSA_VERIFY_FAILED                     -0x4380  /**< The PKCS#1 verification failed. */
+#define POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE                  -0x4400  /**< The output buffer for decryption is not large enough. */
+#define POLARSSL_ERR_RSA_RNG_FAILED                        -0x4480  /**< The random generator failed to generate non-zeros. */
 
 /*
  * PKCS#1 constants
 #define SIG_RSA_MD2     2
 #define SIG_RSA_MD4     3
 #define SIG_RSA_MD5     4
-#define SIG_RSA_SHA1   5
-#define SIG_RSA_SHA224 14
-#define SIG_RSA_SHA256 11
-#define        SIG_RSA_SHA384  12
-#define SIG_RSA_SHA512 13
+#define SIG_RSA_SHA1    5
+#define SIG_RSA_SHA224 14
+#define SIG_RSA_SHA256 11
+#define SIG_RSA_SHA384 12
+#define SIG_RSA_SHA512 13
 
 #define RSA_PUBLIC      0
 #define RSA_PRIVATE     1
 #define RSA_SIGN        1
 #define RSA_CRYPT       2
 
-#define ASN1_STR_CONSTRUCTED_SEQUENCE  "\x30"
-#define ASN1_STR_NULL                          "\x05"
-#define ASN1_STR_OID                           "\x06"
-#define ASN1_STR_OCTET_STRING              "\x04"
+#define ASN1_STR_CONSTRUCTED_SEQUENCE   "\x30"
+#define ASN1_STR_NULL                   "\x05"
+#define ASN1_STR_OID                    "\x06"
+#define ASN1_STR_OCTET_STRING           "\x04"
 
-#define OID_DIGEST_ALG_MDX             "\x2A\x86\x48\x86\xF7\x0D\x02\x00"
-#define OID_HASH_ALG_SHA1              "\x2b\x0e\x03\x02\x1a"
-#define OID_HASH_ALG_SHA2X             "\x60\x86\x48\x01\x65\x03\x04\x02\x00"
+#define OID_DIGEST_ALG_MDX              "\x2A\x86\x48\x86\xF7\x0D\x02\x00"
+#define OID_HASH_ALG_SHA1               "\x2b\x0e\x03\x02\x1a"
+#define OID_HASH_ALG_SHA2X              "\x60\x86\x48\x01\x65\x03\x04\x02\x00"
 
-#define OID_ISO_MEMBER_BODIES      "\x2a"
-#define OID_ISO_IDENTIFIED_ORG     "\x2b"
+#define OID_ISO_MEMBER_BODIES           "\x2a"
+#define OID_ISO_IDENTIFIED_ORG          "\x2b"
 
 /*
  * ISO Member bodies OID parts
  */
-#define OID_COUNTRY_US                 "\x86\x48"
-#define OID_RSA_DATA_SECURITY      "\x86\xf7\x0d"
+#define OID_COUNTRY_US                  "\x86\x48"
+#define OID_RSA_DATA_SECURITY           "\x86\xf7\x0d"
 
 /*
  * ISO Identified organization OID parts
  */
-#define OID_OIW_SECSIG_SHA1            "\x0e\x03\x02\x1a"
+#define OID_OIW_SECSIG_SHA1             "\x0e\x03\x02\x1a"
 
 /*
  * DigestInfo ::= SEQUENCE {
  *
  * Digest ::= OCTET STRING
  */
-#define ASN1_HASH_MDX                                          \
-(                                                                          \
-    ASN1_STR_CONSTRUCTED_SEQUENCE "\x20"               \
-      ASN1_STR_CONSTRUCTED_SEQUENCE "\x0C"             \
-        ASN1_STR_OID "\x08"                                    \
-         OID_DIGEST_ALG_MDX                                    \
-       ASN1_STR_NULL "\x00"                                    \
-      ASN1_STR_OCTET_STRING "\x10"                         \
+#define ASN1_HASH_MDX                           \
+(                                               \
+    ASN1_STR_CONSTRUCTED_SEQUENCE "\x20"        \
+      ASN1_STR_CONSTRUCTED_SEQUENCE "\x0C"      \
+        ASN1_STR_OID "\x08"                     \
+      OID_DIGEST_ALG_MDX                        \
+    ASN1_STR_NULL "\x00"                        \
+      ASN1_STR_OCTET_STRING "\x10"              \
 )
 
-#define ASN1_HASH_SHA1                                         \
-    ASN1_STR_CONSTRUCTED_SEQUENCE "\x21"               \
-      ASN1_STR_CONSTRUCTED_SEQUENCE "\x09"             \
-        ASN1_STR_OID "\x05"                                    \
-         OID_HASH_ALG_SHA1                                         \
-        ASN1_STR_NULL "\x00"                               \
+#define ASN1_HASH_SHA1                          \
+    ASN1_STR_CONSTRUCTED_SEQUENCE "\x21"        \
+      ASN1_STR_CONSTRUCTED_SEQUENCE "\x09"      \
+        ASN1_STR_OID "\x05"                     \
+      OID_HASH_ALG_SHA1                         \
+        ASN1_STR_NULL "\x00"                    \
+      ASN1_STR_OCTET_STRING "\x14"
+
+#define ASN1_HASH_SHA1_ALT                      \
+    ASN1_STR_CONSTRUCTED_SEQUENCE "\x1F"        \
+      ASN1_STR_CONSTRUCTED_SEQUENCE "\x07"      \
+        ASN1_STR_OID "\x05"                     \
+      OID_HASH_ALG_SHA1                         \
       ASN1_STR_OCTET_STRING "\x14"
 
-#define ASN1_HASH_SHA2X                                                \
-    ASN1_STR_CONSTRUCTED_SEQUENCE "\x11"               \
-      ASN1_STR_CONSTRUCTED_SEQUENCE "\x0d"             \
-        ASN1_STR_OID "\x09"                                    \
-         OID_HASH_ALG_SHA2X                                    \
-        ASN1_STR_NULL "\x00"                               \
+#define ASN1_HASH_SHA2X                         \
+    ASN1_STR_CONSTRUCTED_SEQUENCE "\x11"        \
+      ASN1_STR_CONSTRUCTED_SEQUENCE "\x0d"      \
+        ASN1_STR_OID "\x09"                     \
+      OID_HASH_ALG_SHA2X                        \
+        ASN1_STR_NULL "\x00"                    \
       ASN1_STR_OCTET_STRING "\x00"
 
 /**
 typedef struct
 {
     int ver;                    /*!<  always 0          */
-    int len;                    /*!<  size(N) in chars  */
+    size_t len;                 /*!<  size(N) in chars  */
 
     mpi N;                      /*!<  public modulus    */
     mpi E;                      /*!<  public exponent   */
@@ -142,8 +151,12 @@ typedef struct
     mpi RP;                     /*!<  cached R^2 mod P  */
     mpi RQ;                     /*!<  cached R^2 mod Q  */
 
-    int padding;                /*!<  1.5 or OAEP/PSS   */
-    int hash_id;                /*!<  hash identifier   */
+    int padding;                /*!<  RSA_PKCS_V15 for 1.5 padding and
+                                      RSA_PKCS_v21 for OAEP/PSS         */
+    int hash_id;                /*!<  Hash identifier of md_type_t as
+                                      specified in the md.h header file
+                                      for the EME-OAEP and EMSA-PSS
+                                      encoding                          */
 }
 rsa_context;
 
@@ -154,15 +167,15 @@ extern "C" {
 /**
  * \brief          Initialize an RSA context
  *
+ *                 Note: Set padding to RSA_PKCS_V21 for the RSAES-OAEP
+ *                 encryption scheme and the RSASSA-PSS signature scheme.
+ *
  * \param ctx      RSA context to be initialized
  * \param padding  RSA_PKCS_V15 or RSA_PKCS_V21
  * \param hash_id  RSA_PKCS_V21 hash identifier
  *
  * \note           The hash_id parameter is actually ignored
  *                 when using RSA_PKCS_V15 padding.
- *
- * \note           Currently, RSA_PKCS_V21 padding
- *                 is not supported.
  */
 void rsa_init( rsa_context *ctx,
                int padding,
@@ -183,9 +196,9 @@ void rsa_init( rsa_context *ctx,
  * \return         0 if successful, or an POLARSSL_ERR_RSA_XXX error code
  */
 int rsa_gen_key( rsa_context *ctx,
-                 unsigned char (*f_rng)(void *),
+                 int (*f_rng)(void *, unsigned char *, size_t),
                  void *p_rng,
-                 int nbits, int exponent );
+                 unsigned int nbits, int exponent );
 
 /**
  * \brief          Check a public RSA key
@@ -242,10 +255,12 @@ int rsa_private( rsa_context *ctx,
                  unsigned char *output );
 
 /**
- * \brief          Add the message padding, then do an RSA operation
+ * \brief          Generic wrapper to perform a PKCS#1 encryption using the
+ *                 mode from the context. Add the message padding, then do an
+ *                 RSA operation.
  *
  * \param ctx      RSA context
- * \param f_rng    RNG function
+ * \param f_rng    RNG function (Needed for padding and PKCS#1 v2.1 encoding)
  * \param p_rng    RNG parameter
  * \param mode     RSA_PUBLIC or RSA_PRIVATE
  * \param ilen     contains the plaintext length
@@ -258,21 +273,73 @@ int rsa_private( rsa_context *ctx,
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
  */
 int rsa_pkcs1_encrypt( rsa_context *ctx,
-                       unsigned char (*f_rng)(void *),
+                       int (*f_rng)(void *, unsigned char *, size_t),
                        void *p_rng,
-                       int mode, int  ilen,
+                       int mode, size_t ilen,
                        const unsigned char *input,
                        unsigned char *output );
 
 /**
- * \brief          Do an RSA operation, then remove the message padding
+ * \brief          Perform a PKCS#1 v1.5 encryption (RSAES-PKCS1-v1_5-ENCRYPT)
+ *
+ * \param ctx      RSA context
+ * \param f_rng    RNG function (Needed for padding)
+ * \param p_rng    RNG parameter
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param ilen     contains the plaintext length
+ * \param input    buffer holding the data to be encrypted
+ * \param output   buffer that will hold the ciphertext
+ *
+ * \return         0 if successful, or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The output buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ */
+int rsa_rsaes_pkcs1_v15_encrypt( rsa_context *ctx,
+                                 int (*f_rng)(void *, unsigned char *, size_t),
+                                 void *p_rng,
+                                 int mode, size_t ilen,
+                                 const unsigned char *input,
+                                 unsigned char *output );
+
+/**
+ * \brief          Perform a PKCS#1 v2.1 OAEP encryption (RSAES-OAEP-ENCRYPT)
+ *
+ * \param ctx      RSA context
+ * \param f_rng    RNG function (Needed for padding and PKCS#1 v2.1 encoding)
+ * \param p_rng    RNG parameter
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param label    buffer holding the custom label to use
+ * \param label_len contains the label length
+ * \param ilen     contains the plaintext length
+ * \param input    buffer holding the data to be encrypted
+ * \param output   buffer that will hold the ciphertext
+ *
+ * \return         0 if successful, or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The output buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ */
+int rsa_rsaes_oaep_encrypt( rsa_context *ctx,
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng,
+                            int mode,
+                            const unsigned char *label, size_t label_len,
+                            size_t ilen,
+                            const unsigned char *input,
+                            unsigned char *output );
+
+/**
+ * \brief          Generic wrapper to perform a PKCS#1 decryption using the
+ *                 mode from the context. Do an RSA operation, then remove
+ *                 the message padding
  *
  * \param ctx      RSA context
  * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param olen     will contain the plaintext length
  * \param input    buffer holding the encrypted data
  * \param output   buffer that will hold the plaintext
- * \param olen     will contain the plaintext length
- * \param output_max_len       maximum length of the output buffer
+ * \param output_max_len    maximum length of the output buffer
  *
  * \return         0 if successful, or an POLARSSL_ERR_RSA_XXX error code
  *
@@ -281,16 +348,68 @@ int rsa_pkcs1_encrypt( rsa_context *ctx,
  *                 an error is thrown.
  */
 int rsa_pkcs1_decrypt( rsa_context *ctx,
-                       int mode, int *olen,
+                       int mode, size_t *olen,
                        const unsigned char *input,
                        unsigned char *output,
-                              int output_max_len );
+                       size_t output_max_len );
+
+/**
+ * \brief          Perform a PKCS#1 v1.5 decryption (RSAES-PKCS1-v1_5-DECRYPT)
+ *
+ * \param ctx      RSA context
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param olen     will contain the plaintext length
+ * \param input    buffer holding the encrypted data
+ * \param output   buffer that will hold the plaintext
+ * \param output_max_len    maximum length of the output buffer
+ *
+ * \return         0 if successful, or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The output buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used) otherwise
+ *                 an error is thrown.
+ */
+int rsa_rsaes_pkcs1_v15_decrypt( rsa_context *ctx,
+                                 int mode, size_t *olen,
+                                 const unsigned char *input,
+                                 unsigned char *output,
+                                 size_t output_max_len );
 
 /**
- * \brief          Do a private RSA to sign a message digest
+ * \brief          Perform a PKCS#1 v2.1 OAEP decryption (RSAES-OAEP-DECRYPT)
  *
  * \param ctx      RSA context
  * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param label    buffer holding the custom label to use
+ * \param label_len contains the label length
+ * \param olen     will contain the plaintext length
+ * \param input    buffer holding the encrypted data
+ * \param output   buffer that will hold the plaintext
+ * \param output_max_len    maximum length of the output buffer
+ *
+ * \return         0 if successful, or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The output buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used) otherwise
+ *                 an error is thrown.
+ */
+int rsa_rsaes_oaep_decrypt( rsa_context *ctx,
+                            int mode,
+                            const unsigned char *label, size_t label_len,
+                            size_t *olen,
+                            const unsigned char *input,
+                            unsigned char *output,
+                            size_t output_max_len );
+
+/**
+ * \brief          Generic wrapper to perform a PKCS#1 signature using the
+ *                 mode from the context. Do a private RSA operation to sign
+ *                 a message digest
+ *
+ * \param ctx      RSA context
+ * \param f_rng    RNG function (Needed for PKCS#1 v2.1 encoding)
+ * \param p_rng    RNG parameter
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
  * \param hash_id  SIG_RSA_RAW, SIG_RSA_MD{2,4,5} or SIG_RSA_SHA{1,224,256,384,512}
  * \param hashlen  message digest length (for SIG_RSA_RAW only)
  * \param hash     buffer holding the message digest
@@ -301,16 +420,82 @@ int rsa_pkcs1_decrypt( rsa_context *ctx,
  *
  * \note           The "sig" buffer must be as large as the size
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ *
+ * \note           In case of PKCS#1 v2.1 encoding keep in mind that
+ *                 the hash_id in the RSA context is the one used for the
+ *                 encoding. hash_id in the function call is the type of hash
+ *                 that is encoded. According to RFC 3447 it is advised to
+ *                 keep both hashes the same.
  */
 int rsa_pkcs1_sign( rsa_context *ctx,
+                    int (*f_rng)(void *, unsigned char *, size_t),
+                    void *p_rng,
                     int mode,
                     int hash_id,
-                    int hashlen,
+                    unsigned int hashlen,
                     const unsigned char *hash,
                     unsigned char *sig );
 
 /**
- * \brief          Do a public RSA and check the message digest
+ * \brief          Perform a PKCS#1 v1.5 signature (RSASSA-PKCS1-v1_5-SIGN)
+ *
+ * \param ctx      RSA context
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param hash_id  SIG_RSA_RAW, SIG_RSA_MD{2,4,5} or SIG_RSA_SHA{1,224,256,384,512}
+ * \param hashlen  message digest length (for SIG_RSA_RAW only)
+ * \param hash     buffer holding the message digest
+ * \param sig      buffer that will hold the ciphertext
+ *
+ * \return         0 if the signing operation was successful,
+ *                 or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The "sig" buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ */
+int rsa_rsassa_pkcs1_v15_sign( rsa_context *ctx,
+                               int mode,
+                               int hash_id,
+                               unsigned int hashlen,
+                               const unsigned char *hash,
+                               unsigned char *sig );
+
+/**
+ * \brief          Perform a PKCS#1 v2.1 PSS signature (RSASSA-PSS-SIGN)
+ *
+ * \param ctx      RSA context
+ * \param f_rng    RNG function (Needed for PKCS#1 v2.1 encoding)
+ * \param p_rng    RNG parameter
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param hash_id  SIG_RSA_RAW, SIG_RSA_MD{2,4,5} or SIG_RSA_SHA{1,224,256,384,512}
+ * \param hashlen  message digest length (for SIG_RSA_RAW only)
+ * \param hash     buffer holding the message digest
+ * \param sig      buffer that will hold the ciphertext
+ *
+ * \return         0 if the signing operation was successful,
+ *                 or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The "sig" buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ *
+ * \note           In case of PKCS#1 v2.1 encoding keep in mind that
+ *                 the hash_id in the RSA context is the one used for the
+ *                 encoding. hash_id in the function call is the type of hash
+ *                 that is encoded. According to RFC 3447 it is advised to
+ *                 keep both hashes the same.
+ */
+int rsa_rsassa_pss_sign( rsa_context *ctx,
+                         int (*f_rng)(void *, unsigned char *, size_t),
+                         void *p_rng,
+                         int mode,
+                         int hash_id,
+                         unsigned int hashlen,
+                         const unsigned char *hash,
+                         unsigned char *sig );
+
+/**
+ * \brief          Generic wrapper to perform a PKCS#1 verification using the
+ *                 mode from the context. Do a public RSA operation and check
+ *                 the message digest
  *
  * \param ctx      points to an RSA public key
  * \param mode     RSA_PUBLIC or RSA_PRIVATE
@@ -324,13 +509,72 @@ int rsa_pkcs1_sign( rsa_context *ctx,
  *
  * \note           The "sig" buffer must be as large as the size
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ *
+ * \note           In case of PKCS#1 v2.1 encoding keep in mind that
+ *                 the hash_id in the RSA context is the one used for the
+ *                 verification. hash_id in the function call is the type of hash
+ *                 that is verified. According to RFC 3447 it is advised to
+ *                 keep both hashes the same.
  */
 int rsa_pkcs1_verify( rsa_context *ctx,
                       int mode,
                       int hash_id,
-                      int hashlen,
+                      unsigned int hashlen,
                       const unsigned char *hash,
-                      const unsigned char *sig );
+                      unsigned char *sig );
+
+/**
+ * \brief          Perform a PKCS#1 v1.5 verification (RSASSA-PKCS1-v1_5-VERIFY)
+ *
+ * \param ctx      points to an RSA public key
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param hash_id  SIG_RSA_RAW, SIG_RSA_MD{2,4,5} or SIG_RSA_SHA{1,224,256,384,512}
+ * \param hashlen  message digest length (for SIG_RSA_RAW only)
+ * \param hash     buffer holding the message digest
+ * \param sig      buffer holding the ciphertext
+ *
+ * \return         0 if the verify operation was successful,
+ *                 or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The "sig" buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ */
+int rsa_rsassa_pkcs1_v15_verify( rsa_context *ctx,
+                                 int mode,
+                                 int hash_id,
+                                 unsigned int hashlen,
+                                 const unsigned char *hash,
+                                 unsigned char *sig );
+
+/**
+ * \brief          Perform a PKCS#1 v2.1 PSS verification (RSASSA-PSS-VERIFY)
+ * \brief          Do a public RSA and check the message digest
+ *
+ * \param ctx      points to an RSA public key
+ * \param mode     RSA_PUBLIC or RSA_PRIVATE
+ * \param hash_id  SIG_RSA_RAW, SIG_RSA_MD{2,4,5} or SIG_RSA_SHA{1,224,256,384,512}
+ * \param hashlen  message digest length (for SIG_RSA_RAW only)
+ * \param hash     buffer holding the message digest
+ * \param sig      buffer holding the ciphertext
+ *
+ * \return         0 if the verify operation was successful,
+ *                 or an POLARSSL_ERR_RSA_XXX error code
+ *
+ * \note           The "sig" buffer must be as large as the size
+ *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
+ *
+ * \note           In case of PKCS#1 v2.1 encoding keep in mind that
+ *                 the hash_id in the RSA context is the one used for the
+ *                 verification. hash_id in the function call is the type of hash
+ *                 that is verified. According to RFC 3447 it is advised to
+ *                 keep both hashes the same.
+ */
+int rsa_rsassa_pss_verify( rsa_context *ctx,
+                           int mode,
+                           int hash_id,
+                           unsigned int hashlen,
+                           const unsigned char *hash,
+                           unsigned char *sig );
 
 /**
  * \brief          Free the components of an RSA key
index e135e57..e53d9a2 100644 (file)
@@ -1,7 +1,7 @@
 /*
  *  The RSA public-key cryptosystem
  *
- *  Copyright (C) 2006-2010, Brainspark B.V.
+ *  Copyright (C) 2006-2011, Brainspark B.V.
  *
  *  This file is part of PolarSSL (http://www.polarssl.org)
  *  Lead Maintainer: Paul Bakker <polarssl_maintainer at polarssl.org>
 
 #include "polarssl/rsa.h"
 
+#if defined(POLARSSL_PKCS1_V21)
+#include "polarssl/md.h"
+#endif
+
 #include <stdlib.h>
-#include <string.h>
 #include <stdio.h>
 
 /*
@@ -58,9 +61,9 @@ void rsa_init( rsa_context *ctx,
  * Generate an RSA keypair
  */
 int rsa_gen_key( rsa_context *ctx,
-        unsigned char (*f_rng)(void *),
-        void *p_rng,
-        int nbits, int exponent )
+                 int (*f_rng)(void *, unsigned char *, size_t),
+                 void *p_rng,
+                 unsigned int nbits, int exponent )
 {
     int ret;
     mpi P1, Q1, H, G;
@@ -68,7 +71,7 @@ int rsa_gen_key( rsa_context *ctx,
     if( f_rng == NULL || nbits < 128 || exponent < 3 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
-    mpi_init( &P1, &Q1, &H, &G, NULL );
+    mpi_init( &P1 ); mpi_init( &Q1 ); mpi_init( &H ); mpi_init( &G );
 
     /*
      * find primes P and Q with Q < P so that:
@@ -101,7 +104,6 @@ int rsa_gen_key( rsa_context *ctx,
     }
     while( mpi_cmp_int( &G, 1 ) != 0 );
 
-#if 0
     /*
      * D  = E^-1 mod ((P-1)*(Q-1))
      * DP = D mod (P - 1)
@@ -112,18 +114,17 @@ int rsa_gen_key( rsa_context *ctx,
     MPI_CHK( mpi_mod_mpi( &ctx->DP, &ctx->D, &P1 ) );
     MPI_CHK( mpi_mod_mpi( &ctx->DQ, &ctx->D, &Q1 ) );
     MPI_CHK( mpi_inv_mod( &ctx->QP, &ctx->Q, &ctx->P ) );
-#endif
 
     ctx->len = ( mpi_msb( &ctx->N ) + 7 ) >> 3;
 
 cleanup:
 
-    mpi_free( &G, &H, &Q1, &P1, NULL );
+    mpi_free( &P1 ); mpi_free( &Q1 ); mpi_free( &H ); mpi_free( &G );
 
     if( ret != 0 )
     {
         rsa_free( ctx );
-        return( POLARSSL_ERR_RSA_KEY_GEN_FAILED | ret );
+        return( POLARSSL_ERR_RSA_KEY_GEN_FAILED + ret );
     }
 
     return( 0 );   
@@ -131,7 +132,6 @@ cleanup:
 
 #endif
 
-#if 0
 /*
  * Check a public RSA key
  */
@@ -145,7 +145,7 @@ int rsa_check_pubkey( const rsa_context *ctx )
         return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED );
 
     if( mpi_msb( &ctx->N ) < 128 ||
-        mpi_msb( &ctx->N ) > 4096 )
+        mpi_msb( &ctx->N ) > POLARSSL_MPI_MAX_BITS )
         return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED );
 
     if( mpi_msb( &ctx->E ) < 2 ||
@@ -161,7 +161,7 @@ int rsa_check_pubkey( const rsa_context *ctx )
 int rsa_check_privkey( const rsa_context *ctx )
 {
     int ret;
-    mpi PQ, DE, P1, Q1, H, I, G, G2, L1, L2;
+    mpi PQ, DE, P1, Q1, H, I, G, G2, L1, L2, DP, DQ, QP;
 
     if( ( ret = rsa_check_pubkey( ctx ) ) != 0 )
         return( ret );
@@ -169,7 +169,10 @@ int rsa_check_privkey( const rsa_context *ctx )
     if( !ctx->P.p || !ctx->Q.p || !ctx->D.p )
         return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED );
 
-    mpi_init( &PQ, &DE, &P1, &Q1, &H, &I, &G, &G2, &L1, &L2, NULL );
+    mpi_init( &PQ ); mpi_init( &DE ); mpi_init( &P1 ); mpi_init( &Q1 );
+    mpi_init( &H  ); mpi_init( &I  ); mpi_init( &G  ); mpi_init( &G2 );
+    mpi_init( &L1 ); mpi_init( &L2 ); mpi_init( &DP ); mpi_init( &DQ );
+    mpi_init( &QP );
 
     MPI_CHK( mpi_mul_mpi( &PQ, &ctx->P, &ctx->Q ) );
     MPI_CHK( mpi_mul_mpi( &DE, &ctx->D, &ctx->E ) );
@@ -182,25 +185,37 @@ int rsa_check_privkey( const rsa_context *ctx )
     MPI_CHK( mpi_div_mpi( &L1, &L2, &H, &G2 ) );  
     MPI_CHK( mpi_mod_mpi( &I, &DE, &L1  ) );
 
+    MPI_CHK( mpi_mod_mpi( &DP, &ctx->D, &P1 ) );
+    MPI_CHK( mpi_mod_mpi( &DQ, &ctx->D, &Q1 ) );
+    MPI_CHK( mpi_inv_mod( &QP, &ctx->Q, &ctx->P ) );
     /*
      * Check for a valid PKCS1v2 private key
      */
-    if( mpi_cmp_mpi( &PQ, &ctx->N ) == 0 &&
-        mpi_cmp_int( &L2, 0 ) == 0 &&
-        mpi_cmp_int( &I, 1 ) == 0 &&
-        mpi_cmp_int( &G, 1 ) == 0 )
+    if( mpi_cmp_mpi( &PQ, &ctx->N ) != 0 ||
+        mpi_cmp_mpi( &DP, &ctx->DP ) != 0 ||
+        mpi_cmp_mpi( &DQ, &ctx->DQ ) != 0 ||
+        mpi_cmp_mpi( &QP, &ctx->QP ) != 0 ||
+        mpi_cmp_int( &L2, 0 ) != 0 ||
+        mpi_cmp_int( &I, 1 ) != 0 ||
+        mpi_cmp_int( &G, 1 ) != 0 )
     {
-        mpi_free( &G, &I, &H, &Q1, &P1, &DE, &PQ, &G2, &L1, &L2, NULL );
-        return( 0 );
+        ret = POLARSSL_ERR_RSA_KEY_CHECK_FAILED;
     }
-
     
 cleanup:
+    mpi_free( &PQ ); mpi_free( &DE ); mpi_free( &P1 ); mpi_free( &Q1 );
+    mpi_free( &H  ); mpi_free( &I  ); mpi_free( &G  ); mpi_free( &G2 );
+    mpi_free( &L1 ); mpi_free( &L2 ); mpi_free( &DP ); mpi_free( &DQ );
+    mpi_free( &QP );
+
+    if( ret == POLARSSL_ERR_RSA_KEY_CHECK_FAILED )
+        return( ret );
 
-    mpi_free( &G, &I, &H, &Q1, &P1, &DE, &PQ, &G2, &L1, &L2, NULL );
-    return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED | ret );
+    if( ret != 0 )
+        return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED + ret );
+
+    return( 0 );
 }
-#endif
 
 /*
  * Do an RSA public key operation
@@ -209,16 +224,17 @@ int rsa_public( rsa_context *ctx,
                 const unsigned char *input,
                 unsigned char *output )
 {
-    int ret, olen;
+    int ret;
+    size_t olen;
     mpi T;
 
-    mpi_init( &T, NULL );
+    mpi_init( &T );
 
     MPI_CHK( mpi_read_binary( &T, input, ctx->len ) );
 
     if( mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
     {
-        mpi_free( &T, NULL );
+        mpi_free( &T );
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
     }
 
@@ -228,10 +244,10 @@ int rsa_public( rsa_context *ctx,
 
 cleanup:
 
-    mpi_free( &T, NULL );
+    mpi_free( &T );
 
     if( ret != 0 )
-        return( POLARSSL_ERR_RSA_PUBLIC_FAILED | ret );
+        return( POLARSSL_ERR_RSA_PUBLIC_FAILED + ret );
 
     return( 0 );
 }
@@ -243,20 +259,21 @@ int rsa_private( rsa_context *ctx,
                  const unsigned char *input,
                  unsigned char *output )
 {
-    int ret, olen;
+    int ret;
+    size_t olen;
     mpi T, T1, T2;
 
-    mpi_init( &T, &T1, &T2, NULL );
+    mpi_init( &T ); mpi_init( &T1 ); mpi_init( &T2 );
 
     MPI_CHK( mpi_read_binary( &T, input, ctx->len ) );
-#if 0
+
     if( mpi_cmp_mpi( &T, &ctx->N ) >= 0 )
     {
-        mpi_free( &T, NULL );
+        mpi_free( &T );
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
     }
-#endif
-#if 0
+
+#if defined(POLARSSL_RSA_NO_CRT)
     MPI_CHK( mpi_exp_mod( &T, &T, &ctx->D, &ctx->N, &ctx->RN ) );
 #else
     /*
@@ -287,86 +304,251 @@ int rsa_private( rsa_context *ctx,
 
 cleanup:
 
-    mpi_free( &T, &T1, &T2, NULL );
+    mpi_free( &T ); mpi_free( &T1 ); mpi_free( &T2 );
 
     if( ret != 0 )
-        return( POLARSSL_ERR_RSA_PRIVATE_FAILED | ret );
+        return( POLARSSL_ERR_RSA_PRIVATE_FAILED + ret );
 
     return( 0 );
 }
 
+#if defined(POLARSSL_PKCS1_V21)
+/**
+ * Generate and apply the MGF1 operation (from PKCS#1 v2.1) to a buffer.
+ *
+ * \param dst       buffer to mask
+ * \param dlen      length of destination buffer
+ * \param src       source of the mask generation
+ * \param slen      length of the source buffer
+ * \param md_ctx    message digest context to use
+ */
+static void mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src, size_t slen,  
+                       md_context_t *md_ctx )
+{
+    unsigned char mask[POLARSSL_MD_MAX_SIZE];
+    unsigned char counter[4];
+    unsigned char *p;
+    unsigned int hlen;
+    size_t i, use_len;
+
+    memset( mask, 0, POLARSSL_MD_MAX_SIZE );
+    memset( counter, 0, 4 );
+
+    hlen = md_ctx->md_info->size;
+
+    // Generate and apply dbMask
+    //
+    p = dst;
+
+    while( dlen > 0 )
+    {
+        use_len = hlen;
+        if( dlen < hlen )
+            use_len = dlen;
+
+        md_starts( md_ctx );
+        md_update( md_ctx, src, slen );
+        md_update( md_ctx, counter, 4 );
+        md_finish( md_ctx, mask );
+
+        for( i = 0; i < use_len; ++i )
+            *p++ ^= mask[i];
+
+        counter[3]++;
+
+        dlen -= use_len;
+    }
+}
+#endif
+
+#if defined(POLARSSL_PKCS1_V21)
 /*
- * Add the message padding, then do an RSA operation
+ * Implementation of the PKCS#1 v2.1 RSAES-OAEP-ENCRYPT function
  */
-int rsa_pkcs1_encrypt( rsa_context *ctx,
-                       unsigned char (*f_rng)(void *),
-                       void *p_rng,
-                       int mode, int  ilen,
-                       const unsigned char *input,
-                       unsigned char *output )
+int rsa_rsaes_oaep_encrypt( rsa_context *ctx,
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng,
+                            int mode,
+                            const unsigned char *label, size_t label_len,
+                            size_t ilen,
+                            const unsigned char *input,
+                            unsigned char *output )
 {
-    int nb_pad, olen;
+    size_t olen;
+    int ret;
     unsigned char *p = output;
+    unsigned int hlen;
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+
+    if( ctx->padding != RSA_PKCS_V21 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    md_info = md_info_from_type( ctx->hash_id );
+
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     olen = ctx->len;
+    hlen = md_get_size( md_info );
 
-    switch( ctx->padding )
-    {
-        case RSA_PKCS_V15:
+    if( olen < ilen + 2 * hlen + 2 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
-            if( ilen < 0 || olen < ilen + 11 || f_rng == NULL )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    memset( output, 0, olen );
 
-            nb_pad = olen - 3 - ilen;
+    *p++ = 0;
 
-            *p++ = 0;
-            *p++ = RSA_CRYPT;
+    // Generate a random octet string seed
+    //
+    if( ( ret = f_rng( p_rng, p, hlen ) ) != 0 )
+        return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
 
-            while( nb_pad-- > 0 )
-            {
-                int rng_dl = 100;
+    p += hlen;
 
-                do {
-                    *p = f_rng( p_rng );
-                } while( *p == 0 && --rng_dl );
+    // Construct DB
+    //
+    md( md_info, label, label_len, p );
+    p += hlen;
+    p += olen - 2 * hlen - 2 - ilen;
+    *p++ = 1;
+    memcpy( p, input, ilen );
 
-                // Check if RNG failed to generate data
-                //
-                if( rng_dl == 0 )
-                    return POLARSSL_ERR_RSA_RNG_FAILED;
+    md_init_ctx( &md_ctx, md_info );
 
-                p++;
-            }
-            *p++ = 0;
-            memcpy( p, input, ilen );
-            break;
+    // maskedDB: Apply dbMask to DB
+    //
+    mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
+               &md_ctx );
 
-        default:
+    // maskedSeed: Apply seedMask to seed
+    //
+    mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
+               &md_ctx );
 
-            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    md_free_ctx( &md_ctx );
+
+    return( ( mode == RSA_PUBLIC )
+            ? rsa_public(  ctx, output, output )
+            : rsa_private( ctx, output, output ) );
+}
+#endif /* POLARSSL_PKCS1_V21 */
+
+/*
+ * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-ENCRYPT function
+ */
+int rsa_rsaes_pkcs1_v15_encrypt( rsa_context *ctx,
+                                 int (*f_rng)(void *, unsigned char *, size_t),
+                                 void *p_rng,
+                                 int mode, size_t ilen,
+                                 const unsigned char *input,
+                                 unsigned char *output )
+{
+    size_t nb_pad, olen;
+    int ret;
+    unsigned char *p = output;
+
+    if( ctx->padding != RSA_PKCS_V15 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    olen = ctx->len;
+
+    if( olen < ilen + 11 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    nb_pad = olen - 3 - ilen;
+
+    *p++ = 0;
+    if( mode == RSA_PUBLIC )
+    {
+        *p++ = RSA_CRYPT;
+
+        while( nb_pad-- > 0 )
+        {
+            int rng_dl = 100;
+
+            do {
+                ret = f_rng( p_rng, p, 1 );
+            } while( *p == 0 && --rng_dl && ret == 0 );
+
+            // Check if RNG failed to generate data
+            //
+            if( rng_dl == 0 || ret != 0)
+                return POLARSSL_ERR_RSA_RNG_FAILED + ret;
+
+            p++;
+        }
+    }
+    else
+    {
+        *p++ = RSA_SIGN;
+
+        while( nb_pad-- > 0 )
+            *p++ = 0xFF;
     }
 
+    *p++ = 0;
+    memcpy( p, input, ilen );
+
     return( ( mode == RSA_PUBLIC )
             ? rsa_public(  ctx, output, output )
             : rsa_private( ctx, output, output ) );
 }
 
 /*
- * Do an RSA operation, then remove the message padding
+ * Add the message padding, then do an RSA operation
  */
-int rsa_pkcs1_decrypt( rsa_context *ctx,
-                       int mode, int *olen,
+int rsa_pkcs1_encrypt( rsa_context *ctx,
+                       int (*f_rng)(void *, unsigned char *, size_t),
+                       void *p_rng,
+                       int mode, size_t ilen,
                        const unsigned char *input,
-                       unsigned char *output,
-                       int output_max_len)
+                       unsigned char *output )
+{
+    switch( ctx->padding )
+    {
+        case RSA_PKCS_V15:
+            return rsa_rsaes_pkcs1_v15_encrypt( ctx, f_rng, p_rng, mode, ilen,
+                                                input, output );
+
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            return rsa_rsaes_oaep_encrypt( ctx, f_rng, p_rng, mode, NULL, 0,
+                                           ilen, input, output );
+#endif
+
+        default:
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+}
+
+#if defined(POLARSSL_PKCS1_V21)
+/*
+ * Implementation of the PKCS#1 v2.1 RSAES-OAEP-DECRYPT function
+ */
+int rsa_rsaes_oaep_decrypt( rsa_context *ctx,
+                            int mode, 
+                            const unsigned char *label, size_t label_len,
+                            size_t *olen,
+                            const unsigned char *input,
+                            unsigned char *output,
+                            size_t output_max_len )
 {
-    int ret, ilen;
+    int ret;
+    size_t ilen;
     unsigned char *p;
-    unsigned char buf[256];
+    unsigned char buf[POLARSSL_MPI_MAX_SIZE];
+    unsigned char lhash[POLARSSL_MD_MAX_SIZE];
+    unsigned int hlen;
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+
+    if( ctx->padding != RSA_PKCS_V21 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     ilen = ctx->len;
 
-    if( ilen < 16 || ilen > (int) sizeof( buf ) )
+    if( ilen < 16 || ilen > sizeof( buf ) )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     ret = ( mode == RSA_PUBLIC )
@@ -378,107 +560,365 @@ int rsa_pkcs1_decrypt( rsa_context *ctx,
 
     p = buf;
 
-    switch( ctx->padding )
+    if( *p++ != 0 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    hlen = md_get_size( md_info );
+
+    md_init_ctx( &md_ctx, md_info );
+
+    // Generate lHash
+    //
+    md( md_info, label, label_len, lhash );
+
+    // seed: Apply seedMask to maskedSeed
+    //
+    mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
+               &md_ctx );
+
+    // DB: Apply dbMask to maskedDB
+    //
+    mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
+               &md_ctx );
+
+    p += hlen;
+    md_free_ctx( &md_ctx );
+
+    // Check validity
+    //
+    if( memcmp( lhash, p, hlen ) != 0 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    p += hlen;
+
+    while( *p == 0 && p < buf + ilen )
+        p++;
+
+    if( p == buf + ilen )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    if( *p++ != 0x01 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    if (ilen - (p - buf) > output_max_len)
+        return( POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE );
+
+    *olen = ilen - (p - buf);
+    memcpy( output, p, *olen );
+
+    return( 0 );
+}
+#endif /* POLARSSL_PKCS1_V21 */
+
+/*
+ * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-DECRYPT function
+ */
+int rsa_rsaes_pkcs1_v15_decrypt( rsa_context *ctx,
+                                 int mode, size_t *olen,
+                                 const unsigned char *input,
+                                 unsigned char *output,
+                                 size_t output_max_len)
+{
+    int ret, correct = 1;
+    size_t ilen, pad_count = 0;
+    unsigned char *p, *q;
+    unsigned char bt;
+    unsigned char buf[POLARSSL_MPI_MAX_SIZE];
+
+    if( ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    ilen = ctx->len;
+
+    if( ilen < 16 || ilen > sizeof( buf ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    ret = ( mode == RSA_PUBLIC )
+          ? rsa_public(  ctx, input, buf )
+          : rsa_private( ctx, input, buf );
+
+    if( ret != 0 )
+        return( ret );
+
+    p = buf;
+
+    if( *p++ != 0 )
+        correct = 0;
+
+    bt = *p++;
+    if( ( bt != RSA_CRYPT && mode == RSA_PRIVATE ) ||
+        ( bt != RSA_SIGN && mode == RSA_PUBLIC ) )
     {
-        case RSA_PKCS_V15:
+        correct = 0;
+    }
 
-            if( *p++ != 0 || *p++ != RSA_CRYPT )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    if( bt == RSA_CRYPT )
+    {
+        while( *p != 0 && p < buf + ilen - 1 )
+            pad_count += ( *p++ != 0 );
 
-            while( *p != 0 )
-            {
-                if( p >= buf + ilen - 1 )
-                    return( POLARSSL_ERR_RSA_INVALID_PADDING );
-                p++;
-            }
-            p++;
-            break;
+        correct &= ( *p == 0 && p < buf + ilen - 1 );
 
-        default:
+        q = p;
 
-            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        // Also pass over all other bytes to reduce timing differences
+        //
+        while ( q < buf + ilen - 1 )
+            pad_count += ( *q++ != 0 );
+
+        // Prevent compiler optimization of pad_count
+        //
+        correct |= pad_count & 0x100000; /* Always 0 unless 1M bit keys */
+        p++;
     }
+    else
+    {
+        while( *p == 0xFF && p < buf + ilen - 1 )
+            pad_count += ( *p++ == 0xFF );
+
+        correct &= ( *p == 0 && p < buf + ilen - 1 );
+
+        q = p;
+
+        // Also pass over all other bytes to reduce timing differences
+        //
+        while ( q < buf + ilen - 1 )
+            pad_count += ( *q++ != 0 );
+
+        // Prevent compiler optimization of pad_count
+        //
+        correct |= pad_count & 0x100000; /* Always 0 unless 1M bit keys */
+        p++;
+    }
+
+    if( correct == 0 )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
 
-    if (ilen - (int)(p - buf) > output_max_len)
-       return( POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE );
+    if (ilen - (p - buf) > output_max_len)
+        return( POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE );
 
-    *olen = ilen - (int)(p - buf);
+    *olen = ilen - (p - buf);
     memcpy( output, p, *olen );
 
     return( 0 );
 }
 
 /*
- * Do an RSA operation to sign the message digest
+ * Do an RSA operation, then remove the message padding
  */
-int rsa_pkcs1_sign( rsa_context *ctx,
-                    int mode,
-                    int hash_id,
-                    int hashlen,
-                    const unsigned char *hash,
-                    unsigned char *sig )
+int rsa_pkcs1_decrypt( rsa_context *ctx,
+                       int mode, size_t *olen,
+                       const unsigned char *input,
+                       unsigned char *output,
+                       size_t output_max_len)
 {
-    int nb_pad, olen;
+    switch( ctx->padding )
+    {
+        case RSA_PKCS_V15:
+            return rsa_rsaes_pkcs1_v15_decrypt( ctx, mode, olen, input, output,
+                                                output_max_len );
+
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            return rsa_rsaes_oaep_decrypt( ctx, mode, NULL, 0, olen, input,
+                                           output, output_max_len );
+#endif
+
+        default:
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+}
+
+#if defined(POLARSSL_PKCS1_V21)
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PSS-SIGN function
+ */
+int rsa_rsassa_pss_sign( rsa_context *ctx,
+                         int (*f_rng)(void *, unsigned char *, size_t),
+                         void *p_rng,
+                         int mode,
+                         int hash_id,
+                         unsigned int hashlen,
+                         const unsigned char *hash,
+                         unsigned char *sig )
+{
+    size_t olen;
     unsigned char *p = sig;
+    unsigned char salt[POLARSSL_MD_MAX_SIZE];
+    unsigned int slen, hlen, offset = 0;
+    int ret;
+    size_t msb;
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+
+    if( ctx->padding != RSA_PKCS_V21 || f_rng == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     olen = ctx->len;
 
-    switch( ctx->padding )
+    switch( hash_id )
     {
-        case RSA_PKCS_V15:
+        case SIG_RSA_MD2:
+        case SIG_RSA_MD4:
+        case SIG_RSA_MD5:
+            hashlen = 16;
+            break;
 
-            switch( hash_id )
-            {
-                case SIG_RSA_RAW:
-                    nb_pad = olen - 3 - hashlen;
-                    break;
+        case SIG_RSA_SHA1:
+            hashlen = 20;
+            break;
 
-                case SIG_RSA_MD2:
-                case SIG_RSA_MD4:
-                case SIG_RSA_MD5:
-                    nb_pad = olen - 3 - 34;
-                    break;
+        case SIG_RSA_SHA224:
+            hashlen = 28;
+            break;
+
+        case SIG_RSA_SHA256:
+            hashlen = 32;
+            break;
+
+        case SIG_RSA_SHA384:
+            hashlen = 48;
+            break;
+
+        case SIG_RSA_SHA512:
+            hashlen = 64;
+            break;
+
+        default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    }
+
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    hlen = md_get_size( md_info );
+    slen = hlen;
+
+    if( olen < hlen + slen + 2 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    memset( sig, 0, olen );
+
+    msb = mpi_msb( &ctx->N ) - 1;
+
+    // Generate salt of length slen
+    //
+    if( ( ret = f_rng( p_rng, salt, slen ) ) != 0 )
+        return( POLARSSL_ERR_RSA_RNG_FAILED + ret );
+
+    // Note: EMSA-PSS encoding is over the length of N - 1 bits
+    //
+    msb = mpi_msb( &ctx->N ) - 1;
+    p += olen - hlen * 2 - 2;
+    *p++ = 0x01;
+    memcpy( p, salt, slen );
+    p += slen;
+
+    md_init_ctx( &md_ctx, md_info );
 
-                case SIG_RSA_SHA1:
-                    nb_pad = olen - 3 - 35;
-                    break;
+    // Generate H = Hash( M' )
+    //
+    md_starts( &md_ctx );
+    md_update( &md_ctx, p, 8 );
+    md_update( &md_ctx, hash, hashlen );
+    md_update( &md_ctx, salt, slen );
+    md_finish( &md_ctx, p );
 
-                case SIG_RSA_SHA224:
-                    nb_pad = olen - 3 - 47;
-                    break;
+    // Compensate for boundary condition when applying mask
+    //
+    if( msb % 8 == 0 )
+        offset = 1;
 
-                case SIG_RSA_SHA256:
-                    nb_pad = olen - 3 - 51;
-                    break;
+    // maskedDB: Apply dbMask to DB
+    //
+    mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
 
-                case SIG_RSA_SHA384:
-                    nb_pad = olen - 3 - 67;
-                    break;
+    md_free_ctx( &md_ctx );
 
-                case SIG_RSA_SHA512:
-                    nb_pad = olen - 3 - 83;
-                    break;
+    msb = mpi_msb( &ctx->N ) - 1;
+    sig[0] &= 0xFF >> ( olen * 8 - msb );
 
+    p += hlen;
+    *p++ = 0xBC;
 
-                default:
-                    return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
-            }
+    return( ( mode == RSA_PUBLIC )
+            ? rsa_public(  ctx, sig, sig )
+            : rsa_private( ctx, sig, sig ) );
+}
+#endif /* POLARSSL_PKCS1_V21 */
 
-            if( nb_pad < 8 )
-                return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-V1_5-SIGN function
+ */
+/*
+ * Do an RSA operation to sign the message digest
+ */
+int rsa_rsassa_pkcs1_v15_sign( rsa_context *ctx,
+                               int mode,
+                               int hash_id,
+                               unsigned int hashlen,
+                               const unsigned char *hash,
+                               unsigned char *sig )
+{
+    size_t nb_pad, olen;
+    unsigned char *p = sig;
+
+    if( ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
-            *p++ = 0;
-            *p++ = RSA_SIGN;
-            memset( p, 0xFF, nb_pad );
-            p += nb_pad;
-            *p++ = 0;
+    olen = ctx->len;
+
+    switch( hash_id )
+    {
+        case SIG_RSA_RAW:
+            nb_pad = olen - 3 - hashlen;
             break;
 
-        default:
+        case SIG_RSA_MD2:
+        case SIG_RSA_MD4:
+        case SIG_RSA_MD5:
+            nb_pad = olen - 3 - 34;
+            break;
 
-            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        case SIG_RSA_SHA1:
+            nb_pad = olen - 3 - 35;
+            break;
+
+        case SIG_RSA_SHA224:
+            nb_pad = olen - 3 - 47;
+            break;
+
+        case SIG_RSA_SHA256:
+            nb_pad = olen - 3 - 51;
+            break;
+
+        case SIG_RSA_SHA384:
+            nb_pad = olen - 3 - 67;
+            break;
+
+        case SIG_RSA_SHA512:
+            nb_pad = olen - 3 - 83;
+            break;
+
+
+        default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
     }
 
+    if( ( nb_pad < 8 ) || ( nb_pad > olen ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    *p++ = 0;
+    *p++ = RSA_SIGN;
+    memset( p, 0xFF, nb_pad );
+    p += nb_pad;
+    *p++ = 0;
+
     switch( hash_id )
     {
         case SIG_RSA_RAW:
@@ -535,22 +975,62 @@ int rsa_pkcs1_sign( rsa_context *ctx,
 }
 
 /*
- * Do an RSA operation and check the message digest
+ * Do an RSA operation to sign the message digest
  */
-int rsa_pkcs1_verify( rsa_context *ctx,
-                      int mode,
-                      int hash_id,
-                      int hashlen,
-                      const unsigned char *hash,
-                      const unsigned char *sig )
+int rsa_pkcs1_sign( rsa_context *ctx,
+                    int (*f_rng)(void *, unsigned char *, size_t),
+                    void *p_rng,
+                    int mode,
+                    int hash_id,
+                    unsigned int hashlen,
+                    const unsigned char *hash,
+                    unsigned char *sig )
 {
-    int ret, len, siglen;
-    unsigned char *p, c;
-    unsigned char buf[256];
+    switch( ctx->padding )
+    {
+        case RSA_PKCS_V15:
+            return rsa_rsassa_pkcs1_v15_sign( ctx, mode, hash_id,
+                                              hashlen, hash, sig );
+
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            return rsa_rsassa_pss_sign( ctx, f_rng, p_rng, mode, hash_id,
+                                        hashlen, hash, sig );
+#endif
+
+        default:
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+}
+
+#if defined(POLARSSL_PKCS1_V21)
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PSS-VERIFY function
+ */
+int rsa_rsassa_pss_verify( rsa_context *ctx,
+                           int mode,
+                           int hash_id,
+                           unsigned int hashlen,
+                           const unsigned char *hash,
+                           unsigned char *sig )
+{
+    int ret;
+    size_t siglen;
+    unsigned char *p;
+    unsigned char buf[POLARSSL_MPI_MAX_SIZE];
+    unsigned char result[POLARSSL_MD_MAX_SIZE];
+    unsigned char zeros[8];
+    unsigned int hlen;
+    size_t slen, msb;
+    const md_info_t *md_info;
+    md_context_t md_ctx;
+
+    if( ctx->padding != RSA_PKCS_V21 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     siglen = ctx->len;
 
-    if( siglen < 16 || siglen > (int) sizeof( buf ) )
+    if( siglen < 16 || siglen > sizeof( buf ) )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     ret = ( mode == RSA_PUBLIC )
@@ -562,29 +1042,152 @@ int rsa_pkcs1_verify( rsa_context *ctx,
 
     p = buf;
 
-    switch( ctx->padding )
+    if( buf[siglen - 1] != 0xBC )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    switch( hash_id )
     {
-        case RSA_PKCS_V15:
+        case SIG_RSA_MD2:
+        case SIG_RSA_MD4:
+        case SIG_RSA_MD5:
+            hashlen = 16;
+            break;
 
-            if( *p++ != 0 || *p++ != RSA_SIGN )
-                return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        case SIG_RSA_SHA1:
+            hashlen = 20;
+            break;
 
-            while( *p != 0 )
-            {
-                if( p >= buf + siglen - 1 || *p != 0xFF )
-                    return( POLARSSL_ERR_RSA_INVALID_PADDING );
-                p++;
-            }
-            p++;
+        case SIG_RSA_SHA224:
+            hashlen = 28;
+            break;
+
+        case SIG_RSA_SHA256:
+            hashlen = 32;
+            break;
+
+        case SIG_RSA_SHA384:
+            hashlen = 48;
+            break;
+
+        case SIG_RSA_SHA512:
+            hashlen = 64;
             break;
 
         default:
+            return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+    }
+
+    md_info = md_info_from_type( ctx->hash_id );
+    if( md_info == NULL )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    hlen = md_get_size( md_info );
+    slen = siglen - hlen - 1;
+
+    memset( zeros, 0, 8 );
+
+    // Note: EMSA-PSS verification is over the length of N - 1 bits
+    //
+    msb = mpi_msb( &ctx->N ) - 1;
+
+    // Compensate for boundary condition when applying mask
+    //
+    if( msb % 8 == 0 )
+    {
+        p++;
+        siglen -= 1;
+    }
+    if( buf[0] >> ( 8 - siglen * 8 + msb ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    md_init_ctx( &md_ctx, md_info );
+
+    mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx );
+
+    buf[0] &= 0xFF >> ( siglen * 8 - msb );
+
+    while( *p == 0 && p < buf + siglen )
+        p++;
+
+    if( p == buf + siglen ||
+        *p++ != 0x01 )
+    {
+        md_free_ctx( &md_ctx );
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+
+    slen -= p - buf;
+
+    // Generate H = Hash( M' )
+    //
+    md_starts( &md_ctx );
+    md_update( &md_ctx, zeros, 8 );
+    md_update( &md_ctx, hash, hashlen );
+    md_update( &md_ctx, p, slen );
+    md_finish( &md_ctx, result );
+
+    md_free_ctx( &md_ctx );
+
+    if( memcmp( p + slen, result, hlen ) == 0 )
+        return( 0 );
+    else
+        return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+}
+#endif /* POLARSSL_PKCS1_V21 */
+
+/*
+ * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-v1_5-VERIFY function
+ */
+int rsa_rsassa_pkcs1_v15_verify( rsa_context *ctx,
+                                 int mode,
+                                 int hash_id,
+                                 unsigned int hashlen,
+                                 const unsigned char *hash,
+                                 unsigned char *sig )
+{
+    int ret;
+    size_t len, siglen;
+    unsigned char *p, c;
+    unsigned char buf[POLARSSL_MPI_MAX_SIZE];
+
+    if( ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
+    siglen = ctx->len;
+
+    if( siglen < 16 || siglen > sizeof( buf ) )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    ret = ( mode == RSA_PUBLIC )
+          ? rsa_public(  ctx, sig, buf )
+          : rsa_private( ctx, sig, buf );
+
+    if( ret != 0 )
+        return( ret );
+
+    p = buf;
+
+    if( *p++ != 0 || *p++ != RSA_SIGN )
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+
+    while( *p != 0 )
+    {
+        if( p >= buf + siglen - 1 || *p != 0xFF )
             return( POLARSSL_ERR_RSA_INVALID_PADDING );
+        p++;
     }
+    p++;
 
-    len = siglen - (int)( p - buf );
+    len = siglen - ( p - buf );
 
+    if( len == 33 && hash_id == SIG_RSA_SHA1 )
+    {
+        if( memcmp( p, ASN1_HASH_SHA1_ALT, 13 ) == 0 &&
+                memcmp( p + 13, hash, 20 ) == 0 )
+            return( 0 );
+        else
+            return( POLARSSL_ERR_RSA_VERIFY_FAILED );
+    }
     if( len == 34 )
     {
         c = p[13];
@@ -594,10 +1197,10 @@ int rsa_pkcs1_verify( rsa_context *ctx,
             return( POLARSSL_ERR_RSA_VERIFY_FAILED );
 
         if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
-            ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
-            ( c == 5 && hash_id == SIG_RSA_MD5 ) )
+                ( c == 4 && hash_id == SIG_RSA_MD4 ) ||
+                ( c == 5 && hash_id == SIG_RSA_MD5 ) )
         {
-            if( memcmp( p + 18, hash, 16 ) == 0 ) 
+            if( memcmp( p + 18, hash, 16 ) == 0 )
                 return( 0 );
             else
                 return( POLARSSL_ERR_RSA_VERIFY_FAILED );
@@ -607,17 +1210,17 @@ int rsa_pkcs1_verify( rsa_context *ctx,
     if( len == 35 && hash_id == SIG_RSA_SHA1 )
     {
         if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
-            memcmp( p + 15, hash, 20 ) == 0 )
+                memcmp( p + 15, hash, 20 ) == 0 )
             return( 0 );
         else
             return( POLARSSL_ERR_RSA_VERIFY_FAILED );
     }
     if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
-        ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
-        ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
-        ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
+            ( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
+            ( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
+            ( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
     {
-       c = p[1] - 17;
+        c = p[1] - 17;
         p[1] = 17;
         p[14] = 0;
 
@@ -640,15 +1243,42 @@ int rsa_pkcs1_verify( rsa_context *ctx,
     return( POLARSSL_ERR_RSA_INVALID_PADDING );
 }
 
+/*
+ * Do an RSA operation and check the message digest
+ */
+int rsa_pkcs1_verify( rsa_context *ctx,
+                      int mode,
+                      int hash_id,
+                      unsigned int hashlen,
+                      const unsigned char *hash,
+                      unsigned char *sig )
+{
+    switch( ctx->padding )
+    {
+        case RSA_PKCS_V15:
+            return rsa_rsassa_pkcs1_v15_verify( ctx, mode, hash_id,
+                                                hashlen, hash, sig );
+
+#if defined(POLARSSL_PKCS1_V21)
+        case RSA_PKCS_V21:
+            return rsa_rsassa_pss_verify( ctx, mode, hash_id,
+                                          hashlen, hash, sig );
+#endif
+
+        default:
+            return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+}
+
 /*
  * Free the components of an RSA key
  */
 void rsa_free( rsa_context *ctx )
 {
-    mpi_free( &ctx->RQ, &ctx->RP, &ctx->RN,
-              &ctx->QP, &ctx->DQ, &ctx->DP,
-              &ctx->Q,  &ctx->P,  &ctx->D,
-              &ctx->E,  &ctx->N,  NULL );
+    mpi_free( &ctx->RQ ); mpi_free( &ctx->RP ); mpi_free( &ctx->RN );
+    mpi_free( &ctx->QP ); mpi_free( &ctx->DQ ); mpi_free( &ctx->DP );
+    mpi_free( &ctx->Q  ); mpi_free( &ctx->P  ); mpi_free( &ctx->D );
+    mpi_free( &ctx->E  ); mpi_free( &ctx->N  );
 }
 
 #if defined(POLARSSL_SELF_TEST)
@@ -709,12 +1339,17 @@ void rsa_free( rsa_context *ctx )
 #define RSA_PT  "\xAA\xBB\xCC\x03\x02\x01\x00\xFF\xFF\xFF\xFF\xFF" \
                 "\x11\x22\x33\x0A\x0B\x0C\xCC\xDD\xDD\xDD\xDD\xDD"
 
-static int myrand( void *rng_state )
+static int myrand( void *rng_state, unsigned char *output, size_t len )
 {
+    size_t i;
+
     if( rng_state != NULL )
         rng_state  = NULL;
 
-    return( rand() );
+    for( i = 0; i < len; ++i )
+        output[i] = rand();
+    
+    return( 0 );
 }
 
 /*
@@ -722,12 +1357,14 @@ static int myrand( void *rng_state )
  */
 int rsa_self_test( int verbose )
 {
-    int len;
+    size_t len;
     rsa_context rsa;
-    unsigned char sha1sum[20];
     unsigned char rsa_plaintext[PT_LEN];
     unsigned char rsa_decrypted[PT_LEN];
     unsigned char rsa_ciphertext[KEY_LEN];
+#if defined(POLARSSL_SHA1_C)
+    unsigned char sha1sum[20];
+#endif
 
     rsa_init( &rsa, RSA_PKCS_V15, 0 );
 
@@ -772,7 +1409,7 @@ int rsa_self_test( int verbose )
 
     if( rsa_pkcs1_decrypt( &rsa, RSA_PRIVATE, &len,
                            rsa_ciphertext, rsa_decrypted,
-                          sizeof(rsa_decrypted) ) != 0 )
+                           sizeof(rsa_decrypted) ) != 0 )
     {
         if( verbose != 0 )
             printf( "failed\n" );
@@ -788,12 +1425,13 @@ int rsa_self_test( int verbose )
         return( 1 );
     }
 
+#if defined(POLARSSL_SHA1_C)
     if( verbose != 0 )
         printf( "passed\n  PKCS#1 data sign  : " );
 
     sha1( rsa_plaintext, PT_LEN, sha1sum );
 
-    if( rsa_pkcs1_sign( &rsa, RSA_PRIVATE, SIG_RSA_SHA1, 20,
+    if( rsa_pkcs1_sign( &rsa, NULL, NULL, RSA_PRIVATE, SIG_RSA_SHA1, 20,
                         sha1sum, rsa_ciphertext ) != 0 )
     {
         if( verbose != 0 )
@@ -816,6 +1454,7 @@ int rsa_self_test( int verbose )
 
     if( verbose != 0 )
         printf( "passed\n\n" );
+#endif /* POLARSSL_SHA1_C */
 
     rsa_free( &rsa );