s4:heimdal: import lorikeet-heimdal-202201172009 (commit 5a0b45cd723628b3690ea848548b...
[samba.git] / source4 / heimdal / lib / hcrypto / libtommath / bn_mp_prime_rand.c
1 #include "tommath_private.h"
2 #ifdef BN_MP_PRIME_RAND_C
3 /* LibTomMath, multiple-precision integer library -- Tom St Denis */
4 /* SPDX-License-Identifier: Unlicense */
5
6 /* makes a truly random prime of a given size (bits),
7  *
8  * Flags are as follows:
9  *
10  *   MP_PRIME_BBS      - make prime congruent to 3 mod 4
11  *   MP_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies MP_PRIME_BBS)
12  *   MP_PRIME_2MSB_ON  - make the 2nd highest bit one
13  *
14  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
15  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
16  * so it can be NULL
17  *
18  */
19
20 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
21 mp_err s_mp_prime_random_ex(mp_int *a, int t, int size, int flags, private_mp_prime_callback cb, void *dat)
22 {
23    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
24    int bsize, maskOR_msb_offset;
25    mp_bool res;
26    mp_err err;
27
28    /* sanity check the input */
29    if ((size <= 1) || (t <= 0)) {
30       return MP_VAL;
31    }
32
33    /* MP_PRIME_SAFE implies MP_PRIME_BBS */
34    if ((flags & MP_PRIME_SAFE) != 0) {
35       flags |= MP_PRIME_BBS;
36    }
37
38    /* calc the byte size */
39    bsize = (size>>3) + ((size&7)?1:0);
40
41    /* we need a buffer of bsize bytes */
42    tmp = (unsigned char *) MP_MALLOC((size_t)bsize);
43    if (tmp == NULL) {
44       return MP_MEM;
45    }
46
47    /* calc the maskAND value for the MSbyte*/
48    maskAND = ((size&7) == 0) ? 0xFFu : (unsigned char)(0xFFu >> (8 - (size & 7)));
49
50    /* calc the maskOR_msb */
51    maskOR_msb        = 0;
52    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
53    if ((flags & MP_PRIME_2MSB_ON) != 0) {
54       maskOR_msb       |= (unsigned char)(0x80 >> ((9 - size) & 7));
55    }
56
57    /* get the maskOR_lsb */
58    maskOR_lsb         = 1u;
59    if ((flags & MP_PRIME_BBS) != 0) {
60       maskOR_lsb     |= 3u;
61    }
62
63    do {
64       /* read the bytes */
65       if (cb(tmp, bsize, dat) != bsize) {
66          err = MP_VAL;
67          goto error;
68       }
69
70       /* work over the MSbyte */
71       tmp[0]    &= maskAND;
72       tmp[0]    |= (unsigned char)(1 << ((size - 1) & 7));
73
74       /* mix in the maskORs */
75       tmp[maskOR_msb_offset]   |= maskOR_msb;
76       tmp[bsize-1]             |= maskOR_lsb;
77
78       /* read it in */
79       /* TODO: casting only for now until all lengths have been changed to the type "size_t"*/
80       if ((err = mp_from_ubin(a, tmp, (size_t)bsize)) != MP_OKAY) {
81          goto error;
82       }
83
84       /* is it prime? */
85       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY) {
86          goto error;
87       }
88       if (res == MP_NO) {
89          continue;
90       }
91
92       if ((flags & MP_PRIME_SAFE) != 0) {
93          /* see if (a-1)/2 is prime */
94          if ((err = mp_sub_d(a, 1uL, a)) != MP_OKAY) {
95             goto error;
96          }
97          if ((err = mp_div_2(a, a)) != MP_OKAY) {
98             goto error;
99          }
100
101          /* is it prime? */
102          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY) {
103             goto error;
104          }
105       }
106    } while (res == MP_NO);
107
108    if ((flags & MP_PRIME_SAFE) != 0) {
109       /* restore a to the original value */
110       if ((err = mp_mul_2(a, a)) != MP_OKAY) {
111          goto error;
112       }
113       if ((err = mp_add_d(a, 1uL, a)) != MP_OKAY) {
114          goto error;
115       }
116    }
117
118    err = MP_OKAY;
119 error:
120    MP_FREE_BUFFER(tmp, (size_t)bsize);
121    return err;
122 }
123
124 static int s_mp_rand_cb(unsigned char *dst, int len, void *dat)
125 {
126    (void)dat;
127    if (len <= 0) {
128       return len;
129    }
130    if (s_mp_rand_source(dst, (size_t)len) != MP_OKAY) {
131       return 0;
132    }
133    return len;
134 }
135
136 mp_err mp_prime_rand(mp_int *a, int t, int size, int flags)
137 {
138    return s_mp_prime_random_ex(a, t, size, flags, s_mp_rand_cb, NULL);
139 }
140
141 #endif