s4:torture: Adapt KDC canon test to Heimdal upstream changes
[samba.git] / third_party / heimdal / lib / hcrypto / libtommath / etc / tune.c
1 /* Tune the Karatsuba parameters
2  *
3  * Tom St Denis, tstdenis82@gmail.com
4  */
5 #include "../tommath.h"
6 #include "../tommath_private.h"
7 #include <time.h>
8 #include <inttypes.h>
9 #include <errno.h>
10
11 /*
12    Please take in mind that both multiplicands are of the same size. The balancing
13    mechanism in mp_balance works well but has some overhead itself. You can test
14    the behaviour of it with the option "-o" followed by a (small) positive number 'x'
15    to generate ratios of the form 1:x.
16 */
17
18 static uint64_t s_timer_function(void);
19 static void s_timer_start(void);
20 static uint64_t s_timer_stop(void);
21 static uint64_t s_time_mul(int size);
22 static uint64_t s_time_sqr(int size);
23 static void s_usage(char *s);
24
25 static uint64_t s_timer_function(void)
26 {
27 #if _POSIX_C_SOURCE >= 199309L
28 #define LTM_BILLION 1000000000
29    struct timespec ts;
30
31    /* TODO: Sets errno in case of error. Use? */
32    clock_gettime(CLOCK_MONOTONIC, &ts);
33    return (((uint64_t)ts.tv_sec) * LTM_BILLION + (uint64_t)ts.tv_nsec);
34 #else
35    clock_t t;
36    t = clock();
37    if (t < (clock_t)(0)) {
38       return (uint64_t)(0);
39    }
40    return (uint64_t)(t);
41 #endif
42 }
43
44 /* generic ISO C timer */
45 static uint64_t s_timer_tmp;
46 static void s_timer_start(void)
47 {
48    s_timer_tmp = s_timer_function();
49 }
50 static uint64_t s_timer_stop(void)
51 {
52    return s_timer_function() - s_timer_tmp;
53 }
54
55
56 static int s_check_result;
57 static int s_number_of_test_loops;
58 static int s_stabilization_extra;
59 static int s_offset = 1;
60
61 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
62 static uint64_t s_time_mul(int size)
63 {
64    int x;
65    mp_err  e;
66    mp_int  a, b, c, d;
67    uint64_t t1;
68
69    if ((e = mp_init_multi(&a, &b, &c, &d, NULL)) != MP_OKAY) {
70       t1 = UINT64_MAX;
71       goto LTM_ERR;
72    }
73
74    if ((e = mp_rand(&a, size * s_offset)) != MP_OKAY) {
75       t1 = UINT64_MAX;
76       goto LTM_ERR;
77    }
78    if ((e = mp_rand(&b, size)) != MP_OKAY) {
79       t1 = UINT64_MAX;
80       goto LTM_ERR;
81    }
82
83    s_timer_start();
84    for (x = 0; x < s_number_of_test_loops; x++) {
85       if ((e = mp_mul(&a,&b,&c)) != MP_OKAY) {
86          t1 = UINT64_MAX;
87          goto LTM_ERR;
88       }
89       if (s_check_result == 1) {
90          if ((e = s_mp_mul(&a,&b,&d)) != MP_OKAY) {
91             t1 = UINT64_MAX;
92             goto LTM_ERR;
93          }
94          if (mp_cmp(&c, &d) != MP_EQ) {
95             /* Time of 0 cannot happen (famous last words?) */
96             t1 = 0uLL;
97             goto LTM_ERR;
98          }
99       }
100    }
101
102    t1 = s_timer_stop();
103 LTM_ERR:
104    mp_clear_multi(&a, &b, &c, &d, NULL);
105    return t1;
106 }
107
108 static uint64_t s_time_sqr(int size)
109 {
110    int x;
111    mp_err  e;
112    mp_int  a, b, c;
113    uint64_t t1;
114
115    if ((e = mp_init_multi(&a, &b, &c, NULL)) != MP_OKAY) {
116       t1 = UINT64_MAX;
117       goto LTM_ERR;
118    }
119
120    if ((e = mp_rand(&a, size)) != MP_OKAY) {
121       t1 = UINT64_MAX;
122       goto LTM_ERR;
123    }
124
125    s_timer_start();
126    for (x = 0; x < s_number_of_test_loops; x++) {
127       if ((e = mp_sqr(&a,&b)) != MP_OKAY) {
128          t1 = UINT64_MAX;
129          goto LTM_ERR;
130       }
131       if (s_check_result == 1) {
132          if ((e = s_mp_sqr(&a,&c)) != MP_OKAY) {
133             t1 = UINT64_MAX;
134             goto LTM_ERR;
135          }
136          if (mp_cmp(&c, &b) != MP_EQ) {
137             t1 = 0uLL;
138             goto LTM_ERR;
139          }
140       }
141    }
142
143    t1 = s_timer_stop();
144 LTM_ERR:
145    mp_clear_multi(&a, &b, &c, NULL);
146    return t1;
147 }
148
149 struct tune_args {
150    int testmode;
151    int verbose;
152    int print;
153    int bncore;
154    int terse;
155    int upper_limit_print;
156    int increment_print;
157 } args;
158
159 static void s_run(const char *name, uint64_t (*op)(int), int *cutoff)
160 {
161    int x, count = 0;
162    uint64_t t1, t2;
163    if ((args.verbose == 1) || (args.testmode == 1)) {
164       printf("# %s.\n", name);
165    }
166    for (x = 8; x < args.upper_limit_print; x += args.increment_print) {
167       *cutoff = INT_MAX;
168       t1 = op(x);
169       if ((t1 == 0uLL) || (t1 == UINT64_MAX)) {
170          fprintf(stderr,"%s failed at x = INT_MAX (%s)\n", name,
171                  (t1 == 0uLL)?"wrong result":"internal error");
172          exit(EXIT_FAILURE);
173       }
174       *cutoff = x;
175       t2 = op(x);
176       if ((t2 == 0uLL) || (t2 == UINT64_MAX)) {
177          fprintf(stderr,"%s failed (%s)\n", name,
178                  (t2 == 0uLL)?"wrong result":"internal error");
179          exit(EXIT_FAILURE);
180       }
181       if (args.verbose == 1) {
182          printf("%d: %9"PRIu64" %9"PRIu64", %9"PRIi64"\n", x, t1, t2, (int64_t)t2 - (int64_t)t1);
183       }
184       if (t2 < t1) {
185          if (count == s_stabilization_extra) {
186             count = 0;
187             break;
188          } else if (count < s_stabilization_extra) {
189             count++;
190          }
191       } else if (count > 0) {
192          count--;
193       }
194    }
195    *cutoff = x - s_stabilization_extra * args.increment_print;
196 }
197
198 static long s_strtol(const char *str, char **endptr, const char *err)
199 {
200    const int base = 10;
201    char *_endptr;
202    long val;
203    errno = 0;
204    val = strtol(str, &_endptr, base);
205    if ((val > INT_MAX || val < 0) || (errno != 0)) {
206       fprintf(stderr, "Value %s not usable\n", str);
207       exit(EXIT_FAILURE);
208    }
209    if (_endptr == str) {
210       fprintf(stderr, "%s\n", err);
211       exit(EXIT_FAILURE);
212    }
213    if (endptr) *endptr = _endptr;
214    return val;
215 }
216
217 static int s_exit_code = EXIT_FAILURE;
218 static void s_usage(char *s)
219 {
220    fprintf(stderr,"Usage: %s [TvcpGbtrSLFfMmosh]\n",s);
221    fprintf(stderr,"          -T testmode, for use with testme.sh\n");
222    fprintf(stderr,"          -v verbose, print all timings\n");
223    fprintf(stderr,"          -c check results\n");
224    fprintf(stderr,"          -p print benchmark of final cutoffs in files \"multiplying\"\n");
225    fprintf(stderr,"             and \"squaring\"\n");
226    fprintf(stderr,"          -G [string] suffix for the filenames listed above\n");
227    fprintf(stderr,"             Implies '-p'\n");
228    fprintf(stderr,"          -b print benchmark of bncore.c\n");
229    fprintf(stderr,"          -t prints space (0x20) separated results\n");
230    fprintf(stderr,"          -r [64] number of rounds\n");
231    fprintf(stderr,"          -S [0xdeadbeef] seed for PRNG\n");
232    fprintf(stderr,"          -L [3] number of negative values accumulated until the result is accepted\n");
233    fprintf(stderr,"          -M [3000] upper limit of T-C tests/prints\n");
234    fprintf(stderr,"          -m [1] increment of T-C tests/prints\n");
235    fprintf(stderr,"          -o [1] multiplier for the second multiplicand\n");
236    fprintf(stderr,"             (Not for computing the cut-offs!)\n");
237    fprintf(stderr,"          -s 'preset' use values in 'preset' for printing.\n");
238    fprintf(stderr,"             'preset' is a comma separated string with cut-offs for\n");
239    fprintf(stderr,"             ksm, kss, tc3m, tc3s in that order\n");
240    fprintf(stderr,"             ksm  = karatsuba multiplication\n");
241    fprintf(stderr,"             kss  = karatsuba squaring\n");
242    fprintf(stderr,"             tc3m = Toom-Cook 3-way multiplication\n");
243    fprintf(stderr,"             tc3s = Toom-Cook 3-way squaring\n");
244    fprintf(stderr,"             Implies '-p'\n");
245    fprintf(stderr,"          -h this message\n");
246    exit(s_exit_code);
247 }
248
249 struct cutoffs {
250    int KARATSUBA_MUL, KARATSUBA_SQR;
251    int TOOM_MUL, TOOM_SQR;
252 };
253
254 const struct cutoffs max_cutoffs =
255 { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
256
257 static void set_cutoffs(const struct cutoffs *c)
258 {
259    KARATSUBA_MUL_CUTOFF = c->KARATSUBA_MUL;
260    KARATSUBA_SQR_CUTOFF = c->KARATSUBA_SQR;
261    TOOM_MUL_CUTOFF = c->TOOM_MUL;
262    TOOM_SQR_CUTOFF = c->TOOM_SQR;
263 }
264
265 static void get_cutoffs(struct cutoffs *c)
266 {
267    c->KARATSUBA_MUL  = KARATSUBA_MUL_CUTOFF;
268    c->KARATSUBA_SQR  = KARATSUBA_SQR_CUTOFF;
269    c->TOOM_MUL = TOOM_MUL_CUTOFF;
270    c->TOOM_SQR = TOOM_SQR_CUTOFF;
271
272 }
273
274 int main(int argc, char **argv)
275 {
276    uint64_t t1, t2;
277    int x, i, j;
278    size_t n;
279
280    int printpreset = 0;
281    /*int preset[8];*/
282    char *endptr, *str;
283
284    uint64_t seed = 0xdeadbeef;
285
286    int opt;
287    struct cutoffs orig, updated;
288
289    FILE *squaring, *multiplying;
290    char mullog[256] = "multiplying";
291    char sqrlog[256] = "squaring";
292    s_number_of_test_loops = 64;
293    s_stabilization_extra = 3;
294
295    MP_ZERO_BUFFER(&args, sizeof(args));
296
297    args.testmode = 0;
298    args.verbose = 0;
299    args.print = 0;
300    args.bncore = 0;
301    args.terse = 0;
302
303    args.upper_limit_print = 3000;
304    args.increment_print = 1;
305
306    /* Very simple option parser, please treat it nicely. */
307    if (argc != 1) {
308       for (opt = 1; (opt < argc) && (argv[opt][0] == '-'); opt++) {
309          switch (argv[opt][1]) {
310          case 'T':
311             args.testmode = 1;
312             s_check_result = 1;
313             args.upper_limit_print = 1000;
314             args.increment_print = 11;
315             s_number_of_test_loops = 1;
316             s_stabilization_extra = 1;
317             s_offset = 1;
318             break;
319          case 'v':
320             args.verbose = 1;
321             break;
322          case 'c':
323             s_check_result = 1;
324             break;
325          case 'p':
326             args.print = 1;
327             break;
328          case 'G':
329             args.print = 1;
330             opt++;
331             if (opt >= argc) {
332                s_usage(argv[0]);
333             }
334             /* manual strcat() */
335             for (i = 0; i < 255; i++) {
336                if (mullog[i] == '\0') {
337                   break;
338                }
339             }
340             for (j = 0; i < 255; j++, i++) {
341                mullog[i] = argv[opt][j];
342                if (argv[opt][j] == '\0') {
343                   break;
344                }
345             }
346             for (i = 0; i < 255; i++) {
347                if (sqrlog[i] == '\0') {
348                   break;
349                }
350             }
351             for (j = 0; i < 255; j++, i++) {
352                sqrlog[i] = argv[opt][j];
353                if (argv[opt][j] == '\0') {
354                   break;
355                }
356             }
357             break;
358          case 'b':
359             args.bncore = 1;
360             break;
361          case 't':
362             args.terse = 1;
363             break;
364          case 'S':
365             opt++;
366             if (opt >= argc) {
367                s_usage(argv[0]);
368             }
369             str = argv[opt];
370             errno = 0;
371             seed = (uint64_t)s_strtol(argv[opt], NULL, "No seed given?\n");
372             break;
373          case 'L':
374             opt++;
375             if (opt >= argc) {
376                s_usage(argv[0]);
377             }
378             s_stabilization_extra = (int)s_strtol(argv[opt], NULL, "No value for option \"-L\"given");
379             break;
380          case 'o':
381             opt++;
382             if (opt >= argc) {
383                s_usage(argv[0]);
384             }
385             s_offset = (int)s_strtol(argv[opt], NULL, "No value for the offset given");
386             break;
387          case 'r':
388             opt++;
389             if (opt >= argc) {
390                s_usage(argv[0]);
391             }
392             s_number_of_test_loops = (int)s_strtol(argv[opt], NULL, "No value for the number of rounds given");
393             break;
394
395          case 'M':
396             opt++;
397             if (opt >= argc) {
398                s_usage(argv[0]);
399             }
400             args.upper_limit_print = (int)s_strtol(argv[opt], NULL, "No value for the upper limit of T-C tests given");
401             break;
402          case 'm':
403             opt++;
404             if (opt >= argc) {
405                s_usage(argv[0]);
406             }
407             args.increment_print = (int)s_strtol(argv[opt], NULL, "No value for the increment for the T-C tests given");
408             break;
409          case 's':
410             printpreset = 1;
411             args.print = 1;
412             opt++;
413             if (opt >= argc) {
414                s_usage(argv[0]);
415             }
416             str = argv[opt];
417             KARATSUBA_MUL_CUTOFF = (int)s_strtol(str, &endptr, "[1/4] No value for KARATSUBA_MUL_CUTOFF given");
418             str = endptr + 1;
419             KARATSUBA_SQR_CUTOFF = (int)s_strtol(str, &endptr, "[2/4] No value for KARATSUBA_SQR_CUTOFF given");
420             str = endptr + 1;
421             TOOM_MUL_CUTOFF = (int)s_strtol(str, &endptr, "[3/4] No value for TOOM_MUL_CUTOFF given");
422             str = endptr + 1;
423             TOOM_SQR_CUTOFF = (int)s_strtol(str, &endptr, "[4/4] No value for TOOM_SQR_CUTOFF given");
424             break;
425          case 'h':
426             s_exit_code = EXIT_SUCCESS;
427          /* FALLTHROUGH */
428          default:
429             s_usage(argv[0]);
430          }
431       }
432    }
433
434    /*
435      mp_rand uses the cryptographically secure
436      source of the OS by default. That is too expensive, too slow and
437      most important for a benchmark: it is not repeatable.
438    */
439    s_mp_rand_jenkins_init(seed);
440    mp_rand_source(s_mp_rand_jenkins);
441
442    get_cutoffs(&orig);
443
444    updated = max_cutoffs;
445    if ((args.bncore == 0) && (printpreset == 0)) {
446       struct {
447          const char *name;
448          int *cutoff, *update;
449          uint64_t (*fn)(int);
450       } test[] = {
451 #define T_MUL_SQR(n, o, f)  { #n, &o##_CUTOFF, &(updated.o), MP_HAS(S_MP_##o) ? f : NULL }
452          /*
453             The influence of the Comba multiplication cannot be
454             eradicated programmatically. It depends on the size
455             of the macro MP_WPARRAY in tommath.h which needs to
456             be changed manually (to 0 (zero)).
457           */
458          T_MUL_SQR("Karatsuba multiplication", KARATSUBA_MUL, s_time_mul),
459          T_MUL_SQR("Karatsuba squaring", KARATSUBA_SQR, s_time_sqr),
460          T_MUL_SQR("Toom-Cook 3-way multiplying", TOOM_MUL, s_time_mul),
461          T_MUL_SQR("Toom-Cook 3-way squaring", TOOM_SQR, s_time_sqr),
462 #undef T_MUL_SQR
463       };
464       /* Turn all limits from bncore.c to the max */
465       set_cutoffs(&max_cutoffs);
466       for (n = 0; n < sizeof(test)/sizeof(test[0]); ++n) {
467          if (test[n].fn) {
468             s_run(test[n].name, test[n].fn, test[n].cutoff);
469             *test[n].update = *test[n].cutoff;
470             *test[n].cutoff = INT_MAX;
471          }
472       }
473    }
474    if (args.terse == 1) {
475       printf("%d %d %d %d\n",
476              updated.KARATSUBA_MUL,
477              updated.KARATSUBA_SQR,
478              updated.TOOM_MUL,
479              updated.TOOM_SQR);
480    } else {
481       printf("KARATSUBA_MUL_CUTOFF = %d\n", updated.KARATSUBA_MUL);
482       printf("KARATSUBA_SQR_CUTOFF = %d\n", updated.KARATSUBA_SQR);
483       printf("TOOM_MUL_CUTOFF = %d\n", updated.TOOM_MUL);
484       printf("TOOM_SQR_CUTOFF = %d\n", updated.TOOM_SQR);
485    }
486
487    if (args.print == 1) {
488       printf("Printing data for graphing to \"%s\" and \"%s\"\n",mullog, sqrlog);
489
490       multiplying = fopen(mullog, "w+");
491       if (multiplying == NULL) {
492          fprintf(stderr, "Opening file \"%s\" failed\n", mullog);
493          exit(EXIT_FAILURE);
494       }
495
496       squaring = fopen(sqrlog, "w+");
497       if (squaring == NULL) {
498          fprintf(stderr, "Opening file \"%s\" failed\n",sqrlog);
499          exit(EXIT_FAILURE);
500       }
501
502       for (x = 8; x < args.upper_limit_print; x += args.increment_print) {
503          set_cutoffs(&max_cutoffs);
504          t1 = s_time_mul(x);
505          set_cutoffs(&orig);
506          t2 = s_time_mul(x);
507          fprintf(multiplying, "%d: %9"PRIu64" %9"PRIu64", %9"PRIi64"\n", x, t1, t2, (int64_t)t2 - (int64_t)t1);
508          fflush(multiplying);
509          if (args.verbose == 1) {
510             printf("MUL %d: %9"PRIu64" %9"PRIu64", %9"PRIi64"\n", x, t1, t2, (int64_t)t2 - (int64_t)t1);
511             fflush(stdout);
512          }
513          set_cutoffs(&max_cutoffs);
514          t1 = s_time_sqr(x);
515          set_cutoffs(&orig);
516          t2 = s_time_sqr(x);
517          fprintf(squaring,"%d: %9"PRIu64" %9"PRIu64", %9"PRIi64"\n", x, t1, t2, (int64_t)t2 - (int64_t)t1);
518          fflush(squaring);
519          if (args.verbose == 1) {
520             printf("SQR %d: %9"PRIu64" %9"PRIu64", %9"PRIi64"\n", x, t1, t2, (int64_t)t2 - (int64_t)t1);
521             fflush(stdout);
522          }
523       }
524       printf("Finished. Data for graphing in \"%s\" and \"%s\"\n",mullog, sqrlog);
525       if (args.verbose == 1) {
526          set_cutoffs(&orig);
527          if (args.terse == 1) {
528             printf("%d %d %d %d\n",
529                    KARATSUBA_MUL_CUTOFF,
530                    KARATSUBA_SQR_CUTOFF,
531                    TOOM_MUL_CUTOFF,
532                    TOOM_SQR_CUTOFF);
533          } else {
534             printf("KARATSUBA_MUL_CUTOFF = %d\n", KARATSUBA_MUL_CUTOFF);
535             printf("KARATSUBA_SQR_CUTOFF = %d\n", KARATSUBA_SQR_CUTOFF);
536             printf("TOOM_MUL_CUTOFF = %d\n", TOOM_MUL_CUTOFF);
537             printf("TOOM_SQR_CUTOFF = %d\n", TOOM_SQR_CUTOFF);
538          }
539       }
540    }
541    exit(EXIT_SUCCESS);
542 }