r23456: Update Samba4 to current lorikeet-heimdal.
[ab/samba.git/.git] / source4 / heimdal / lib / asn1 / gen_decode.c
1 /*
2  * Copyright (c) 1997 - 2006 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden). 
4  * All rights reserved. 
5  *
6  * Redistribution and use in source and binary forms, with or without 
7  * modification, are permitted provided that the following conditions 
8  * are met: 
9  *
10  * 1. Redistributions of source code must retain the above copyright 
11  *    notice, this list of conditions and the following disclaimer. 
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright 
14  *    notice, this list of conditions and the following disclaimer in the 
15  *    documentation and/or other materials provided with the distribution. 
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors 
18  *    may be used to endorse or promote products derived from this software 
19  *    without specific prior written permission. 
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND 
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE 
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 
31  * SUCH DAMAGE. 
32  */
33
34 #include "gen_locl.h"
35 #include "lex.h"
36
37 RCSID("$Id: gen_decode.c 19572 2006-12-29 17:30:32Z lha $");
38
39 static void
40 decode_primitive (const char *typename, const char *name, const char *forwstr)
41 {
42 #if 0
43     fprintf (codefile,
44              "e = decode_%s(p, len, %s, &l);\n"
45              "%s;\n",
46              typename,
47              name,
48              forwstr);
49 #else
50     fprintf (codefile,
51              "e = der_get_%s(p, len, %s, &l);\n"
52              "if(e) %s;\np += l; len -= l; ret += l;\n",
53              typename,
54              name,
55              forwstr);
56 #endif
57 }
58
59 static int
60 is_primitive_type(int type)
61 {
62     switch(type) {
63     case TInteger:
64     case TBoolean:
65     case TOctetString:
66     case TBitString:
67     case TEnumerated:
68     case TGeneralizedTime:
69     case TGeneralString:
70     case TOID:
71     case TUTCTime:
72     case TUTF8String:
73     case TPrintableString:
74     case TIA5String:
75     case TBMPString:
76     case TUniversalString:
77     case TVisibleString:
78     case TNull:
79         return 1;
80     default:
81         return 0;
82     }
83 }
84
85 static void
86 find_tag (const Type *t,
87           Der_class *cl, Der_type *ty, unsigned *tag)
88 {
89     switch (t->type) {
90     case TBitString:
91         *cl  = ASN1_C_UNIV;
92         *ty  = PRIM;
93         *tag = UT_BitString;
94         break;
95     case TBoolean:
96         *cl  = ASN1_C_UNIV;
97         *ty  = PRIM;
98         *tag = UT_Boolean;
99         break;
100     case TChoice: 
101         errx(1, "Cannot have recursive CHOICE");
102     case TEnumerated:
103         *cl  = ASN1_C_UNIV;
104         *ty  = PRIM;
105         *tag = UT_Enumerated;
106         break;
107     case TGeneralString: 
108         *cl  = ASN1_C_UNIV;
109         *ty  = PRIM;
110         *tag = UT_GeneralString;
111         break;
112     case TGeneralizedTime: 
113         *cl  = ASN1_C_UNIV;
114         *ty  = PRIM;
115         *tag = UT_GeneralizedTime;
116         break;
117     case TIA5String:
118         *cl  = ASN1_C_UNIV;
119         *ty  = PRIM;
120         *tag = UT_IA5String;
121         break;
122     case TInteger: 
123         *cl  = ASN1_C_UNIV;
124         *ty  = PRIM;
125         *tag = UT_Integer;
126         break;
127     case TNull:
128         *cl  = ASN1_C_UNIV;
129         *ty  = PRIM;
130         *tag = UT_Null;
131         break;
132     case TOID: 
133         *cl  = ASN1_C_UNIV;
134         *ty  = PRIM;
135         *tag = UT_OID;
136         break;
137     case TOctetString: 
138         *cl  = ASN1_C_UNIV;
139         *ty  = PRIM;
140         *tag = UT_OctetString;
141         break;
142     case TPrintableString:
143         *cl  = ASN1_C_UNIV;
144         *ty  = PRIM;
145         *tag = UT_PrintableString;
146         break;
147     case TSequence: 
148     case TSequenceOf:
149         *cl  = ASN1_C_UNIV;
150         *ty  = CONS;
151         *tag = UT_Sequence;
152         break;
153     case TSet: 
154     case TSetOf:
155         *cl  = ASN1_C_UNIV;
156         *ty  = CONS;
157         *tag = UT_Set;
158         break;
159     case TTag: 
160         *cl  = t->tag.tagclass;
161         *ty  = is_primitive_type(t->subtype->type) ? PRIM : CONS;
162         *tag = t->tag.tagvalue;
163         break;
164     case TType: 
165         if ((t->symbol->stype == Stype && t->symbol->type == NULL)
166             || t->symbol->stype == SUndefined) {
167             error_message("%s is imported or still undefined, "
168                           " can't generate tag checking data in CHOICE "
169                           "without this information",
170                           t->symbol->name);
171             exit(1);
172         }
173         find_tag(t->symbol->type, cl, ty, tag);
174         return;
175     case TUTCTime: 
176         *cl  = ASN1_C_UNIV;
177         *ty  = PRIM;
178         *tag = UT_UTCTime;
179         break;
180     case TUTF8String:
181         *cl  = ASN1_C_UNIV;
182         *ty  = PRIM;
183         *tag = UT_UTF8String;
184         break;
185     case TBMPString:
186         *cl  = ASN1_C_UNIV;
187         *ty  = PRIM;
188         *tag = UT_BMPString;
189         break;
190     case TUniversalString:
191         *cl  = ASN1_C_UNIV;
192         *ty  = PRIM;
193         *tag = UT_UniversalString;
194         break;
195     case TVisibleString:
196         *cl  = ASN1_C_UNIV;
197         *ty  = PRIM;
198         *tag = UT_VisibleString;
199         break;
200     default:
201         abort();
202     }
203 }
204
205 static int
206 decode_type (const char *name, const Type *t, int optional, 
207              const char *forwstr, const char *tmpstr)
208 {
209     switch (t->type) {
210     case TType: {
211         if (optional)
212             fprintf(codefile, 
213                     "%s = calloc(1, sizeof(*%s));\n"
214                     "if (%s == NULL) %s;\n",
215                     name, name, name, forwstr);
216         fprintf (codefile,
217                  "e = decode_%s(p, len, %s, &l);\n",
218                  t->symbol->gen_name, name);
219         if (optional) {
220             fprintf (codefile,
221                      "if(e) {\n"
222                      "free(%s);\n"
223                      "%s = NULL;\n"
224                      "} else {\n"
225                      "p += l; len -= l; ret += l;\n"
226                      "}\n",
227                      name, name);
228         } else {
229             fprintf (codefile,
230                      "if(e) %s;\n",
231                      forwstr);
232             fprintf (codefile,
233                      "p += l; len -= l; ret += l;\n");
234         }
235         break;
236     }
237     case TInteger:
238         if(t->members) {
239             char *s;
240             asprintf(&s, "(int*)%s", name);
241             if (s == NULL)
242                 errx (1, "out of memory");
243             decode_primitive ("integer", s, forwstr);
244             free(s);
245         } else if (t->range == NULL) {
246             decode_primitive ("heim_integer", name, forwstr);
247         } else if (t->range->min == INT_MIN && t->range->max == INT_MAX) {
248             decode_primitive ("integer", name, forwstr);
249         } else if (t->range->min == 0 && t->range->max == UINT_MAX) {
250             decode_primitive ("unsigned", name, forwstr);
251         } else if (t->range->min == 0 && t->range->max == INT_MAX) {
252             decode_primitive ("unsigned", name, forwstr);
253         } else
254             errx(1, "%s: unsupported range %d -> %d", 
255                  name, t->range->min, t->range->max);
256         break;
257     case TBoolean:
258       decode_primitive ("boolean", name, forwstr);
259       break;
260     case TEnumerated:
261         decode_primitive ("enumerated", name, forwstr);
262         break;
263     case TOctetString:
264         decode_primitive ("octet_string", name, forwstr);
265         break;
266     case TBitString: {
267         Member *m;
268         int pos = 0;
269
270         if (ASN1_TAILQ_EMPTY(t->members)) {
271             decode_primitive ("bit_string", name, forwstr);
272             break;
273         }
274         fprintf(codefile,
275                 "if (len < 1) return ASN1_OVERRUN;\n"
276                 "p++; len--; ret++;\n");
277         fprintf(codefile,
278                 "do {\n"
279                 "if (len < 1) break;\n");
280         ASN1_TAILQ_FOREACH(m, t->members, members) {
281             while (m->val / 8 > pos / 8) {
282                 fprintf (codefile,
283                          "p++; len--; ret++;\n"
284                          "if (len < 1) break;\n");
285                 pos += 8;
286             }
287             fprintf (codefile,
288                      "(%s)->%s = (*p >> %d) & 1;\n",
289                      name, m->gen_name, 7 - m->val % 8);
290         }
291         fprintf(codefile,
292                 "} while(0);\n");
293         fprintf (codefile,
294                  "p += len; ret += len;\n");
295         break;
296     }
297     case TSequence: {
298         Member *m;
299
300         if (t->members == NULL)
301             break;
302
303         ASN1_TAILQ_FOREACH(m, t->members, members) {
304             char *s;
305
306             if (m->ellipsis)
307                 continue;
308
309             asprintf (&s, "%s(%s)->%s", m->optional ? "" : "&",
310                       name, m->gen_name);
311             if (s == NULL)
312                 errx(1, "malloc");
313             decode_type (s, m->type, m->optional, forwstr, m->gen_name);
314             free (s);
315         }
316         
317         break;
318     }
319     case TSet: {
320         Member *m;
321         unsigned int memno;
322
323         if(t->members == NULL)
324             break;
325
326         fprintf(codefile, "{\n");
327         fprintf(codefile, "unsigned int members = 0;\n");
328         fprintf(codefile, "while(len > 0) {\n");
329         fprintf(codefile, 
330                 "Der_class class;\n"
331                 "Der_type type;\n"
332                 "int tag;\n"
333                 "e = der_get_tag (p, len, &class, &type, &tag, NULL);\n"
334                 "if(e) %s;\n", forwstr);
335         fprintf(codefile, "switch (MAKE_TAG(class, type, tag)) {\n");
336         memno = 0;
337         ASN1_TAILQ_FOREACH(m, t->members, members) {
338             char *s;
339
340             assert(m->type->type == TTag);
341
342             fprintf(codefile, "case MAKE_TAG(%s, %s, %s):\n",
343                     classname(m->type->tag.tagclass),
344                     is_primitive_type(m->type->subtype->type) ? "PRIM" : "CONS",
345                     valuename(m->type->tag.tagclass, m->type->tag.tagvalue));
346
347             asprintf (&s, "%s(%s)->%s", m->optional ? "" : "&", name, m->gen_name);
348             if (s == NULL)
349                 errx(1, "malloc");
350             if(m->optional)
351                 fprintf(codefile, 
352                         "%s = calloc(1, sizeof(*%s));\n"
353                         "if (%s == NULL) { e = ENOMEM; %s; }\n",
354                         s, s, s, forwstr);
355             decode_type (s, m->type, 0, forwstr, m->gen_name);
356             free (s);
357
358             fprintf(codefile, "members |= (1 << %d);\n", memno);
359             memno++;
360             fprintf(codefile, "break;\n");
361         }
362         fprintf(codefile, 
363                 "default:\n"
364                 "return ASN1_MISPLACED_FIELD;\n"
365                 "break;\n");
366         fprintf(codefile, "}\n");
367         fprintf(codefile, "}\n");
368         memno = 0;
369         ASN1_TAILQ_FOREACH(m, t->members, members) {
370             char *s;
371
372             asprintf (&s, "%s->%s", name, m->gen_name);
373             if (s == NULL)
374                 errx(1, "malloc");
375             fprintf(codefile, "if((members & (1 << %d)) == 0)\n", memno);
376             if(m->optional)
377                 fprintf(codefile, "%s = NULL;\n", s);
378             else if(m->defval)
379                 gen_assign_defval(s, m->defval);
380             else
381                 fprintf(codefile, "return ASN1_MISSING_FIELD;\n");
382             free(s);
383             memno++;
384         }
385         fprintf(codefile, "}\n");
386         break;
387     }
388     case TSetOf:
389     case TSequenceOf: {
390         char *n;
391         char *sname;
392
393         fprintf (codefile,
394                  "{\n"
395                  "size_t %s_origlen = len;\n"
396                  "size_t %s_oldret = ret;\n"
397                  "void *%s_tmp;\n"
398                  "ret = 0;\n"
399                  "(%s)->len = 0;\n"
400                  "(%s)->val = NULL;\n"
401                  "while(ret < %s_origlen) {\n"
402                  "%s_tmp = realloc((%s)->val, "
403                  "    sizeof(*((%s)->val)) * ((%s)->len + 1));\n"
404                  "if (%s_tmp == NULL) { %s; }\n"
405                  "(%s)->val = %s_tmp;\n",
406                  tmpstr, tmpstr, tmpstr,
407                  name, name,
408                  tmpstr, tmpstr,
409                  name, name, name,
410                  tmpstr, forwstr, 
411                  name, tmpstr);
412
413         asprintf (&n, "&(%s)->val[(%s)->len]", name, name);
414         if (n == NULL)
415             errx(1, "malloc");
416         asprintf (&sname, "%s_s_of", tmpstr);
417         if (sname == NULL)
418             errx(1, "malloc");
419         decode_type (n, t->subtype, 0, forwstr, sname);
420         fprintf (codefile, 
421                  "(%s)->len++;\n"
422                  "len = %s_origlen - ret;\n"
423                  "}\n"
424                  "ret += %s_oldret;\n"
425                  "}\n",
426                  name,
427                  tmpstr, tmpstr);
428         free (n);
429         free (sname);
430         break;
431     }
432     case TGeneralizedTime:
433         decode_primitive ("generalized_time", name, forwstr);
434         break;
435     case TGeneralString:
436         decode_primitive ("general_string", name, forwstr);
437         break;
438     case TTag:{
439         char *tname;
440
441         fprintf(codefile, 
442                 "{\n"
443                 "size_t %s_datalen, %s_oldlen;\n",
444                 tmpstr, tmpstr);
445         if(dce_fix)
446             fprintf(codefile, 
447                     "int dce_fix;\n");
448         fprintf(codefile, "e = der_match_tag_and_length(p, len, %s, %s, %s, "
449                 "&%s_datalen, &l);\n",
450                 classname(t->tag.tagclass),
451                 is_primitive_type(t->subtype->type) ? "PRIM" : "CONS",
452                 valuename(t->tag.tagclass, t->tag.tagvalue),
453                 tmpstr);
454         if(optional) {
455             fprintf(codefile, 
456                     "if(e) {\n"
457                     "%s = NULL;\n"
458                     "} else {\n"
459                      "%s = calloc(1, sizeof(*%s));\n"
460                      "if (%s == NULL) { e = ENOMEM; %s; }\n",
461                      name, name, name, name, forwstr);
462         } else {
463             fprintf(codefile, "if(e) %s;\n", forwstr);
464         }
465         fprintf (codefile,
466                  "p += l; len -= l; ret += l;\n"
467                  "%s_oldlen = len;\n",
468                  tmpstr);
469         if(dce_fix)
470             fprintf (codefile,
471                      "if((dce_fix = _heim_fix_dce(%s_datalen, &len)) < 0)\n"
472                      "{ e = ASN1_BAD_FORMAT; %s; }\n",
473                      tmpstr, forwstr);
474         else
475             fprintf(codefile, 
476                     "if (%s_datalen > len) { e = ASN1_OVERRUN; %s; }\n"
477                     "len = %s_datalen;\n", tmpstr, forwstr, tmpstr);
478         asprintf (&tname, "%s_Tag", tmpstr);
479         if (tname == NULL)
480             errx(1, "malloc");
481         decode_type (name, t->subtype, 0, forwstr, tname);
482         if(dce_fix)
483             fprintf(codefile,
484                     "if(dce_fix){\n"
485                     "e = der_match_tag_and_length (p, len, "
486                     "(Der_class)0,(Der_type)0, UT_EndOfContent, "
487                     "&%s_datalen, &l);\n"
488                     "if(e) %s;\np += l; len -= l; ret += l;\n"
489                     "} else \n", tmpstr, forwstr);
490         fprintf(codefile, 
491                 "len = %s_oldlen - %s_datalen;\n",
492                 tmpstr, tmpstr);
493         if(optional)
494             fprintf(codefile, 
495                     "}\n");
496         fprintf(codefile, 
497                 "}\n");
498         free(tname);
499         break;
500     }
501     case TChoice: {
502         Member *m, *have_ellipsis = NULL;
503         const char *els = "";
504
505         if (t->members == NULL)
506             break;
507
508         ASN1_TAILQ_FOREACH(m, t->members, members) {
509             const Type *tt = m->type;
510             char *s;
511             Der_class cl;
512             Der_type  ty;
513             unsigned  tag;
514             
515             if (m->ellipsis) {
516                 have_ellipsis = m;
517                 continue;
518             }
519
520             find_tag(tt, &cl, &ty, &tag);
521
522             fprintf(codefile,
523                     "%sif (der_match_tag(p, len, %s, %s, %s, NULL) == 0) {\n",
524                     els,
525                     classname(cl),
526                     ty ? "CONS" : "PRIM",
527                     valuename(cl, tag));
528             asprintf (&s, "%s(%s)->u.%s", m->optional ? "" : "&",
529                       name, m->gen_name);
530             if (s == NULL)
531                 errx(1, "malloc");
532             decode_type (s, m->type, m->optional, forwstr, m->gen_name);
533             fprintf(codefile,
534                     "(%s)->element = %s;\n",
535                     name, m->label);
536             free(s);
537             fprintf(codefile,
538                     "}\n");
539             els = "else ";
540         }
541         if (have_ellipsis) {
542             fprintf(codefile,
543                     "else {\n"
544                     "(%s)->u.%s.data = calloc(1, len);\n"
545                     "if ((%s)->u.%s.data == NULL) {\n"
546                     "e = ENOMEM; %s;\n"
547                     "}\n"
548                     "(%s)->u.%s.length = len;\n"
549                     "memcpy((%s)->u.%s.data, p, len);\n"
550                     "(%s)->element = %s;\n"
551                     "p += len;\n"
552                     "ret += len;\n"
553                     "len -= len;\n"
554                     "}\n",
555                     name, have_ellipsis->gen_name,
556                     name, have_ellipsis->gen_name,
557                     forwstr, 
558                     name, have_ellipsis->gen_name,
559                     name, have_ellipsis->gen_name,
560                     name, have_ellipsis->label);
561         } else {
562             fprintf(codefile,
563                     "else {\n"
564                     "e = ASN1_PARSE_ERROR;\n"
565                     "%s;\n"
566                     "}\n",
567                     forwstr);
568         }
569         break;
570     }
571     case TUTCTime:
572         decode_primitive ("utctime", name, forwstr);
573         break;
574     case TUTF8String:
575         decode_primitive ("utf8string", name, forwstr);
576         break;
577     case TPrintableString:
578         decode_primitive ("printable_string", name, forwstr);
579         break;
580     case TIA5String:
581         decode_primitive ("ia5_string", name, forwstr);
582         break;
583     case TBMPString:
584         decode_primitive ("bmp_string", name, forwstr);
585         break;
586     case TUniversalString:
587         decode_primitive ("universal_string", name, forwstr);
588         break;
589     case TVisibleString:
590         decode_primitive ("visible_string", name, forwstr);
591         break;
592     case TNull:
593         fprintf (codefile, "/* NULL */\n");
594         break;
595     case TOID:
596         decode_primitive ("oid", name, forwstr);
597         break;
598     default :
599         abort ();
600     }
601     return 0;
602 }
603
604 void
605 generate_type_decode (const Symbol *s)
606 {
607     int preserve = preserve_type(s->name) ? TRUE : FALSE;
608
609     fprintf (headerfile,
610              "int    "
611              "decode_%s(const unsigned char *, size_t, %s *, size_t *);\n",
612              s->gen_name, s->gen_name);
613
614     fprintf (codefile, "int\n"
615              "decode_%s(const unsigned char *p,"
616              " size_t len, %s *data, size_t *size)\n"
617              "{\n",
618              s->gen_name, s->gen_name);
619
620     switch (s->type->type) {
621     case TInteger:
622     case TBoolean:
623     case TOctetString:
624     case TOID:
625     case TGeneralizedTime:
626     case TGeneralString:
627     case TUTF8String:
628     case TPrintableString:
629     case TIA5String:
630     case TBMPString:
631     case TUniversalString:
632     case TVisibleString:
633     case TUTCTime:
634     case TNull:
635     case TEnumerated:
636     case TBitString:
637     case TSequence:
638     case TSequenceOf:
639     case TSet:
640     case TSetOf:
641     case TTag:
642     case TType:
643     case TChoice:
644         fprintf (codefile,
645                  "size_t ret = 0;\n"
646                  "size_t l;\n"
647                  "int e;\n");
648         if (preserve)
649             fprintf (codefile, "const unsigned char *begin = p;\n");
650
651         fprintf (codefile, "\n");
652         fprintf (codefile, "memset(data, 0, sizeof(*data));\n"); /* hack to avoid `unused variable' */
653
654         decode_type ("data", s->type, 0, "goto fail", "Top");
655         if (preserve)
656             fprintf (codefile,
657                      "data->_save.data = calloc(1, ret);\n"
658                      "if (data->_save.data == NULL) { \n"
659                      "e = ENOMEM; goto fail; \n"
660                      "}\n"
661                      "data->_save.length = ret;\n"
662                      "memcpy(data->_save.data, begin, ret);\n");
663         fprintf (codefile, 
664                  "if(size) *size = ret;\n"
665                  "return 0;\n");
666         fprintf (codefile,
667                  "fail:\n"
668                  "free_%s(data);\n"
669                  "return e;\n",
670                  s->gen_name);
671         break;
672     default:
673         abort ();
674     }
675     fprintf (codefile, "}\n\n");
676 }