mpi_exp_mod: Use stack instead of malloc
authorNIIBE Yutaka <gniibe@fsij.org>
Mon, 23 Dec 2013 07:15:39 +0000 (16:15 +0900)
committerNIIBE Yutaka <gniibe@fsij.org>
Mon, 23 Dec 2013 07:17:20 +0000 (16:17 +0900)
ChangeLog
doc/note/HACKING
polarssl/library/bignum.c
src/gnuk.ld.in

index 019669f..dda577c 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -2,6 +2,11 @@
 
        * polarssl/library/bignum.c (mpi_montmul): Computation
        time should not depends on input.
+       (mpi_montmul, mpi_montred, mpi_montsqr): Change the API.
+       (mpi_exp_mod): Follow the change of the API.  Allocate memory on
+       stack instead of malloc.
+
+       * src/gnuk.ld.in (__process3_stack_size__): Increase stack size.
 
 2013-12-20  Niibe Yutaka  <gniibe@fsij.org>
 
index fc5be9a..6d54b5e 100644 (file)
@@ -5,7 +5,7 @@ It is important to collect enough entropy.  Perhaps, it would
 be possible to get entropy from USB traffic (of other devices).
 
 
-* RSA
+* [Mostly DONE] RSA
 
 It would be good not to use malloc.
 
index 2a8c904..da9a069 100644 (file)
@@ -1382,14 +1382,13 @@ static void mpi_montg_init( t_uint *mm, const mpi *N )
 
 /*
  * Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
- * A is placed at the upper half of T.
+ * A is placed at the upper half of D.
  */
-static void mpi_montmul( const mpi *B, const mpi *N, t_uint mm, mpi *T )
+static void mpi_montmul( const t_uint *bp, const mpi *N, t_uint mm, t_uint *d )
 {
     size_t i, n;
-    t_uint u0, u1, *d, c = 0;
+    t_uint u0, u1, c = 0;
 
-    d = T->p;
     n = N->n;
 
     for( i = 0; i < n; i++ )
@@ -1399,9 +1398,9 @@ static void mpi_montmul( const mpi *B, const mpi *N, t_uint mm, mpi *T )
          */
         u0 = d[n];
         d[n] = c;
-        u1 = ( d[0] + u0 * B->p[0] ) * mm;
+        u1 = ( d[0] + u0 * bp[0] ) * mm;
 
-        mpi_mul_hlp( n, B->p, d, u0 );
+        mpi_mul_hlp( n, bp, d, u0 );
         c = mpi_mul_hlp( n, N->p, d, u1 );
         d++;
     }
@@ -1410,19 +1409,18 @@ static void mpi_montmul( const mpi *B, const mpi *N, t_uint mm, mpi *T )
     if( ((mpi_cmp_abs_limbs ( n, d, N->p ) >= 0) | c) )
         mpi_sub_hlp( n, N->p, d );
     else
-        mpi_sub_hlp( n, T->p, T->p);
+        mpi_sub_hlp( n, d - n, d - n);
 }
 
 /*
  * Montgomery reduction: A = A * R^-1 mod N
- * A is placed at the upper half of T.
+ * A is placed at the upper half of D.
  */
-static void mpi_montred( const mpi *N, t_uint mm, mpi *T )
+static void mpi_montred( const mpi *N, t_uint mm, t_uint *d )
 {
     size_t i, j, n;
-    t_uint u0, u1, *d, c = 0;
+    t_uint u0, u1, c = 0;
 
-    d = T->p;
     n = N->n;
 
     for( i = 0; i < n; i++ )
@@ -1449,19 +1447,18 @@ static void mpi_montred( const mpi *N, t_uint mm, mpi *T )
     if( ((mpi_cmp_abs_limbs ( n, d, N->p ) >= 0) | c) )
         mpi_sub_hlp( n, N->p, d );
     else
-        mpi_sub_hlp( n, T->p, T->p);
+        mpi_sub_hlp( n, d - n, d - n);
 }
 
 /*
  * Montgomery square: A = A * A * R^-1 mod N
- * A is placed at the upper half of T.
+ * A is placed at the upper half of D.
  */
-static void mpi_montsqr( const mpi *N, t_uint mm, mpi *T )
+static void mpi_montsqr( const mpi *N, t_uint mm, t_uint *d )
 {
   size_t n, i;
-  t_uint c = 0, *d;
+  t_uint c = 0;
 
-  d = T->p;
   n = N->n;
 
   for (i = 0; i < n; i++)
@@ -1544,13 +1541,13 @@ static void mpi_montsqr( const mpi *N, t_uint mm, mpi *T )
         c += mpi_mul_hlp( n, N->p, &d[i], u );
     }
 
-  d = T->p + n;
+  d += n;
 
   /* prevent timing attacks */
   if( ((mpi_cmp_abs_limbs ( n, d, N->p ) >= 0) | c) )
       mpi_sub_hlp( n, N->p, d );
   else
-      mpi_sub_hlp( n, T->p, T->p);
+      mpi_sub_hlp( n, d - n, d - n);
 }
 
 /*
@@ -1559,12 +1556,17 @@ static void mpi_montsqr( const mpi *N, t_uint mm, mpi *T )
 int mpi_exp_mod( mpi *X, const mpi *A, const mpi *E, const mpi *N, mpi *_RR )
 {
     int ret;
-    size_t wbits, wsize, one = 1;
-    size_t i, j, nblimbs;
+    size_t i = mpi_msb( E );
+    size_t wsize = ( i > 671 ) ? 6 : ( i > 239 ) ? 5 :
+                   ( i >  79 ) ? 4 : ( i >  23 ) ? 3 : 1;
+    size_t wbits, one = 1;
+    size_t nblimbs;
     size_t bufsize, nbits;
     t_uint ei, mm, state;
-    mpi RR, T, W[ 2 << POLARSSL_MPI_WINDOW_SIZE ], Apos;
-    int neg;
+    mpi RR;
+    t_uint d[N->n*2];
+    t_uint w1[N->n];
+    t_uint wn[(one << (wsize - 1))][N->n];
 
     if( mpi_cmp_int( N, 0 ) < 0 || ( N->p[0] & 1 ) == 0 )
         return( POLARSSL_ERR_MPI_BAD_INPUT_DATA );
@@ -1572,97 +1574,76 @@ int mpi_exp_mod( mpi *X, const mpi *A, const mpi *E, const mpi *N, mpi *_RR )
     if( mpi_cmp_int( E, 0 ) < 0 )
         return( POLARSSL_ERR_MPI_BAD_INPUT_DATA );
 
+    if( A->s == -1 )
+        return( POLARSSL_ERR_MPI_BAD_INPUT_DATA );
+
     /*
      * Init temps and window size
      */
     mpi_montg_init( &mm, N );
-    mpi_init( &RR ); mpi_init( &T );
-    memset( W, 0, sizeof( W ) );
-
-    i = mpi_msb( E );
-
-    wsize = ( i > 671 ) ? 6 : ( i > 239 ) ? 5 :
-            ( i >  79 ) ? 4 : ( i >  23 ) ? 3 : 1;
-
-    if( wsize > POLARSSL_MPI_WINDOW_SIZE )
-        wsize = POLARSSL_MPI_WINDOW_SIZE;
-
-    j = N->n;
     MPI_CHK( mpi_grow( X, N->n ) );
-    MPI_CHK( mpi_grow( &W[1],  N->n ) );
-    MPI_CHK( mpi_grow( &T, N->n * 2 ) ); /* T = 0 here.  */
-
-    /*
-     * Compensate for negative A (and correct at the end)
-     */
-    neg = ( A->s == -1 );
-
-    mpi_init( &Apos );
-    if( neg )
-    {
-        MPI_CHK( mpi_copy( &Apos, A ) );
-        Apos.s = 1;
-        A = &Apos;
-    }
 
     /*
      * If 1st call, pre-compute R^2 mod N
      */
     if( _RR == NULL || _RR->p == NULL )
     {
-        /* T->p is all zero here. */
-        mpi_sub_hlp( N->n, N->p, T.p + N->n);
+        mpi T;
+
+        mpi_init( &RR );
+        T.s = 1; T.n = N->n * 2; T.p = d;
+        memset (d, 0, 2 * N->n * ciL); /* Set D zero. */
+        mpi_sub_hlp( N->n, N->p, d + N->n);
         MPI_CHK( mpi_mod_mpi( &RR, &T, N ) );
 
         if( _RR != NULL )
             memcpy( _RR, &RR, sizeof( mpi ) );
 
-        /* The condition of "the lower half of T is all zero" is kept. */
+        /* The condition of "the lower half of D is all zero" is kept. */
     }
-    else
+    else {
         memcpy( &RR, _RR, sizeof( mpi ) );
+        memset (d, 0, N->n * ciL); /* Set lower half of D zero. */
+    }
 
     /*
      * W[1] = A * R^2 * R^-1 mod N = A * R mod N
      */
-    if( mpi_cmp_mpi( A, N ) >= 0 )
-        mpi_mod_mpi( &W[1], A, N );
-    else   mpi_copy( &W[1], A );
+    if( mpi_cmp_mpi( A, N ) >= 0 ) {
+        mpi W1;
+        W1.s = 1; W1.n = N->n; W1.p = d + N->n;
+        mpi_mod_mpi( &W1, A, N );
+    } else {
+        memset (d + N->n, 0, N->n * ciL);
+        memcpy (d + N->n, A->p, A->n * ciL);
+    }
 
-    memcpy ( T.p + N->n, W[1].p, N->n * ciL);
-    mpi_montmul( &RR, N, mm, &T );
-    memcpy ( W[1].p, T.p + N->n, N->n * ciL);
+    mpi_montmul( RR.p, N, mm, d );
+    memcpy (w1, d + N->n, N->n * ciL);
 
-    if( wsize > 1 )
     {
         /*
          * W[1 << (wsize - 1)] = W[1] ^ ( 2 ^ (wsize - 1) )
          */
-        j =  one << (wsize - 1);
-
-        MPI_CHK( mpi_grow( &W[j], N->n  ) );
-
         for( i = 0; i < wsize - 1; i++ )
-            mpi_montsqr( N, mm, &T );
-        memcpy ( W[j].p, T.p + N->n, N->n * ciL);
+            mpi_montsqr( N, mm, d );
+        memcpy (wn[0], d + N->n, N->n * ciL);
 
         /*
          * W[i] = W[i - 1] * W[1]
          */
-        for( i = j + 1; i < (one << wsize); i++ )
+        for( i = 1; i < (one << (wsize - 1)); i++ )
         {
-            MPI_CHK( mpi_grow( &W[i], N->n      ) );
-
-            mpi_montmul( &W[1], N, mm, &T );
-            memcpy ( W[i].p, T.p + N->n, N->n * ciL);
+            mpi_montmul( w1, N, mm, d );
+            memcpy (wn[i], d + N->n, N->n * ciL);
         }
     }
 
     /*
      * X = R^2 * R^-1 mod N = R mod N
      */
-    memcpy ( T.p + N->n, RR.p, N->n * ciL);
-    mpi_montred( N, mm, &T );
+    memcpy (d + N->n, RR.p, N->n * ciL);
+    mpi_montred( N, mm, d );
 
     nblimbs = E->n;
     bufsize = 0;
@@ -1695,7 +1676,7 @@ int mpi_exp_mod( mpi *X, const mpi *A, const mpi *E, const mpi *N, mpi *_RR )
             /*
              * out of window, square X
              */
-            mpi_montsqr( N, mm, &T );
+            mpi_montsqr( N, mm, d );
             continue;
         }
 
@@ -1713,12 +1694,12 @@ int mpi_exp_mod( mpi *X, const mpi *A, const mpi *E, const mpi *N, mpi *_RR )
              * X = X^wsize R^-1 mod N
              */
             for( i = 0; i < wsize; i++ )
-                mpi_montsqr( N, mm, &T );
+                mpi_montsqr( N, mm, d );
 
             /*
              * X = X * W[wbits] R^-1 mod N
              */
-            mpi_montmul( &W[wbits], N, mm, &T );
+            mpi_montmul( wn[wbits - (one << (wsize - 1))], N, mm, d );
 
             state--;
             nbits = 0;
@@ -1731,33 +1712,22 @@ int mpi_exp_mod( mpi *X, const mpi *A, const mpi *E, const mpi *N, mpi *_RR )
      */
     for( i = 0; i < nbits; i++ )
     {
-        mpi_montsqr( N, mm, &T );
+        mpi_montsqr( N, mm, d );
 
         wbits <<= 1;
 
         if( (wbits & (one << wsize)) != 0 )
-            mpi_montmul( &W[1], N, mm, &T );
+            mpi_montmul( w1, N, mm, d );
     }
 
     /*
      * X = A^E * R * R^-1 mod N = A^E mod N
      */
-    mpi_montred( N, mm, &T );
-    memcpy ( X->p, T.p + N->n, N->n * ciL);
-
-    if( neg )
-    {
-        X->s = -1;
-        mpi_add_mpi( X, N, X );
-    }
+    mpi_montred( N, mm, d );
+    memcpy (X->p, d + N->n, N->n * ciL);
 
 cleanup:
 
-    for( i = (one << (wsize - 1)); i < (one << wsize); i++ )
-        mpi_free( &W[i] );
-
-    mpi_free( &W[1] ); mpi_free( &T ); mpi_free( &Apos );
-
     if( _RR == NULL )
         mpi_free( &RR );
 
index 5d4f90b..a7d19de 100644 (file)
@@ -5,7 +5,7 @@ __main_stack_size__      = 0x0100;      /* Exception handlers     */
 __process0_stack_size__  = 0x0100;      /* main */
 __process1_stack_size__  = 0x0140;      /* ccid */
 __process2_stack_size__  = 0x0180;      /* rng */
-__process3_stack_size__  = 0x0b00;      /* gpg */
+__process3_stack_size__  = 0x1600;      /* gpg */
 __process4_stack_size__  = 0x0100;      /* intr: usb */
 __process5_stack_size__  = @MSC_SIZE@;  /* msc */
 __process6_stack_size__  = @TIM_SIZE@;  /* intr: timer */