r26316: Use contexts for conversion functions.
[jelmer/samba4-debian.git] / source / auth / ntlmssp / ntlmssp_parse.c
1 /* 
2    Unix SMB/CIFS implementation.
3    simple kerberos5/SPNEGO routines
4    Copyright (C) Andrew Tridgell 2001
5    Copyright (C) Jim McDonough <jmcd@us.ibm.com> 2002
6    Copyright (C) Andrew Bartlett 2002-2003
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 3 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 */
21
22 #include "includes.h"
23 #include "pstring.h"
24
25 /*
26   this is a tiny msrpc packet generator. I am only using this to
27   avoid tying this code to a particular varient of our rpc code. This
28   generator is not general enough for all our rpc needs, its just
29   enough for the spnego/ntlmssp code
30
31   format specifiers are:
32
33   U = unicode string (input is unix string)
34   a = address (input is char *unix_string)
35       (1 byte type, 1 byte length, unicode/ASCII string, all inline)
36   A = ASCII string (input is unix string)
37   B = data blob (pointer + length)
38   b = data blob in header (pointer + length)
39   D
40   d = word (4 bytes)
41   C = constant ascii string
42  */
43 bool msrpc_gen(TALLOC_CTX *mem_ctx, DATA_BLOB *blob,
44                const char *format, ...)
45 {
46         int i;
47         ssize_t n;
48         va_list ap;
49         char *s;
50         uint8_t *b;
51         int head_size=0, data_size=0;
52         int head_ofs, data_ofs;
53         int *intargs;
54
55         DATA_BLOB *pointers;
56
57         pointers = talloc_array(mem_ctx, DATA_BLOB, strlen(format));
58         intargs = talloc_array(pointers, int, strlen(format));
59
60         /* first scan the format to work out the header and body size */
61         va_start(ap, format);
62         for (i=0; format[i]; i++) {
63                 switch (format[i]) {
64                 case 'U':
65                         s = va_arg(ap, char *);
66                         head_size += 8;
67                         n = push_ucs2_talloc(pointers, global_smb_iconv_convenience, (void **)&pointers[i].data, s);
68                         if (n == -1) {
69                                 return false;
70                         }
71                         pointers[i].length = n;
72                         pointers[i].length -= 2;
73                         data_size += pointers[i].length;
74                         break;
75                 case 'A':
76                         s = va_arg(ap, char *);
77                         head_size += 8;
78                         n = push_ascii_talloc(pointers, global_smb_iconv_convenience, (char **)&pointers[i].data, s);
79                         if (n == -1) {
80                                 return false;
81                         }
82                         pointers[i].length = n;
83                         pointers[i].length -= 1;
84                         data_size += pointers[i].length;
85                         break;
86                 case 'a':
87                         n = va_arg(ap, int);
88                         intargs[i] = n;
89                         s = va_arg(ap, char *);
90                         n = push_ucs2_talloc(pointers, global_smb_iconv_convenience, (void **)&pointers[i].data, s);
91                         if (n == -1) {
92                                 return false;
93                         }
94                         pointers[i].length = n;
95                         pointers[i].length -= 2;
96                         data_size += pointers[i].length + 4;
97                         break;
98                 case 'B':
99                         b = va_arg(ap, uint8_t *);
100                         head_size += 8;
101                         pointers[i].data = b;
102                         pointers[i].length = va_arg(ap, int);
103                         data_size += pointers[i].length;
104                         break;
105                 case 'b':
106                         b = va_arg(ap, uint8_t *);
107                         pointers[i].data = b;
108                         pointers[i].length = va_arg(ap, int);
109                         head_size += pointers[i].length;
110                         break;
111                 case 'd':
112                         n = va_arg(ap, int);
113                         intargs[i] = n;
114                         head_size += 4;
115                         break;
116                 case 'C':
117                         s = va_arg(ap, char *);
118                         pointers[i].data = (uint8_t *)s;
119                         pointers[i].length = strlen(s)+1;
120                         head_size += pointers[i].length;
121                         break;
122                 }
123         }
124         va_end(ap);
125
126         /* allocate the space, then scan the format again to fill in the values */
127         *blob = data_blob_talloc(mem_ctx, NULL, head_size + data_size);
128
129         head_ofs = 0;
130         data_ofs = head_size;
131
132         va_start(ap, format);
133         for (i=0; format[i]; i++) {
134                 switch (format[i]) {
135                 case 'U':
136                 case 'A':
137                 case 'B':
138                         n = pointers[i].length;
139                         SSVAL(blob->data, head_ofs, n); head_ofs += 2;
140                         SSVAL(blob->data, head_ofs, n); head_ofs += 2;
141                         SIVAL(blob->data, head_ofs, data_ofs); head_ofs += 4;
142                         if (pointers[i].data && n) /* don't follow null pointers... */
143                                 memcpy(blob->data+data_ofs, pointers[i].data, n);
144                         data_ofs += n;
145                         break;
146                 case 'a':
147                         n = intargs[i];
148                         SSVAL(blob->data, data_ofs, n); data_ofs += 2;
149
150                         n = pointers[i].length;
151                         SSVAL(blob->data, data_ofs, n); data_ofs += 2;
152                         if (n >= 0) {
153                                 memcpy(blob->data+data_ofs, pointers[i].data, n);
154                         }
155                         data_ofs += n;
156                         break;
157                 case 'd':
158                         n = intargs[i];
159                         SIVAL(blob->data, head_ofs, n); 
160                         head_ofs += 4;
161                         break;
162                 case 'b':
163                         n = pointers[i].length;
164                         memcpy(blob->data + head_ofs, pointers[i].data, n);
165                         head_ofs += n;
166                         break;
167                 case 'C':
168                         n = pointers[i].length;
169                         memcpy(blob->data + head_ofs, pointers[i].data, n);
170                         head_ofs += n;
171                         break;
172                 }
173         }
174         va_end(ap);
175         
176         talloc_free(pointers);
177
178         return true;
179 }
180
181
182 /* a helpful macro to avoid running over the end of our blob */
183 #define NEED_DATA(amount) \
184 if ((head_ofs + amount) > blob->length) { \
185         return false; \
186 }
187
188 /*
189   this is a tiny msrpc packet parser. This the the partner of msrpc_gen
190
191   format specifiers are:
192
193   U = unicode string (output is unix string)
194   A = ascii string
195   B = data blob
196   b = data blob in header
197   d = word (4 bytes)
198   C = constant ascii string
199  */
200
201 bool msrpc_parse(TALLOC_CTX *mem_ctx, const DATA_BLOB *blob,
202                  const char *format, ...)
203 {
204         int i;
205         va_list ap;
206         const char **ps, *s;
207         DATA_BLOB *b;
208         size_t head_ofs = 0;
209         uint16_t len1, len2;
210         uint32_t ptr;
211         uint32_t *v;
212         pstring p;
213
214         va_start(ap, format);
215         for (i=0; format[i]; i++) {
216                 switch (format[i]) {
217                 case 'U':
218                         NEED_DATA(8);
219                         len1 = SVAL(blob->data, head_ofs); head_ofs += 2;
220                         len2 = SVAL(blob->data, head_ofs); head_ofs += 2;
221                         ptr =  IVAL(blob->data, head_ofs); head_ofs += 4;
222
223                         ps = (const char **)va_arg(ap, char **);
224                         if (len1 == 0 && len2 == 0) {
225                                 *ps = "";
226                         } else {
227                                 /* make sure its in the right format - be strict */
228                                 if ((len1 != len2) || (ptr + len1 < ptr) || (ptr + len1 < len1) || (ptr + len1 > blob->length)) {
229                                         return false;
230                                 }
231                                 if (len1 & 1) {
232                                         /* if odd length and unicode */
233                                         return false;
234                                 }
235                                 if (blob->data + ptr < (uint8_t *)ptr || blob->data + ptr < blob->data)
236                                         return false;
237
238                                 if (0 < len1) {
239                                         pull_string(global_smb_iconv_convenience, p, blob->data + ptr, sizeof(p), 
240                                                     len1, STR_UNICODE|STR_NOALIGN);
241                                         (*ps) = talloc_strdup(mem_ctx, p);
242                                         if (!(*ps)) {
243                                                 return false;
244                                         }
245                                 } else {
246                                         (*ps) = "";
247                                 }
248                         }
249                         break;
250                 case 'A':
251                         NEED_DATA(8);
252                         len1 = SVAL(blob->data, head_ofs); head_ofs += 2;
253                         len2 = SVAL(blob->data, head_ofs); head_ofs += 2;
254                         ptr =  IVAL(blob->data, head_ofs); head_ofs += 4;
255
256                         ps = (const char **)va_arg(ap, char **);
257                         /* make sure its in the right format - be strict */
258                         if (len1 == 0 && len2 == 0) {
259                                 *ps = "";
260                         } else {
261                                 if ((len1 != len2) || (ptr + len1 < ptr) || (ptr + len1 < len1) || (ptr + len1 > blob->length)) {
262                                         return false;
263                                 }
264
265                                 if (blob->data + ptr < (uint8_t *)ptr || blob->data + ptr < blob->data)
266                                         return false;   
267
268                                 if (0 < len1) {
269                                         pull_string(global_smb_iconv_convenience, p, blob->data + ptr, sizeof(p), 
270                                                     len1, STR_ASCII|STR_NOALIGN);
271                                         (*ps) = talloc_strdup(mem_ctx, p);
272                                         if (!(*ps)) {
273                                                 return false;
274                                         }
275                                 } else {
276                                         (*ps) = "";
277                                 }
278                         }
279                         break;
280                 case 'B':
281                         NEED_DATA(8);
282                         len1 = SVAL(blob->data, head_ofs); head_ofs += 2;
283                         len2 = SVAL(blob->data, head_ofs); head_ofs += 2;
284                         ptr =  IVAL(blob->data, head_ofs); head_ofs += 4;
285
286                         b = (DATA_BLOB *)va_arg(ap, void *);
287                         if (len1 == 0 && len2 == 0) {
288                                 *b = data_blob_talloc(mem_ctx, NULL, 0);
289                         } else {
290                                 /* make sure its in the right format - be strict */
291                                 if ((len1 != len2) || (ptr + len1 < ptr) || (ptr + len1 < len1) || (ptr + len1 > blob->length)) {
292                                         return false;
293                                 }
294
295                                 if (blob->data + ptr < (uint8_t *)ptr || blob->data + ptr < blob->data)
296                                         return false;   
297                         
298                                 *b = data_blob_talloc(mem_ctx, blob->data + ptr, len1);
299                         }
300                         break;
301                 case 'b':
302                         b = (DATA_BLOB *)va_arg(ap, void *);
303                         len1 = va_arg(ap, uint_t);
304                         /* make sure its in the right format - be strict */
305                         NEED_DATA(len1);
306                         if (blob->data + head_ofs < (uint8_t *)head_ofs || blob->data + head_ofs < blob->data)
307                                 return false;   
308                         
309                         *b = data_blob_talloc(mem_ctx, blob->data + head_ofs, len1);
310                         head_ofs += len1;
311                         break;
312                 case 'd':
313                         v = va_arg(ap, uint32_t *);
314                         NEED_DATA(4);
315                         *v = IVAL(blob->data, head_ofs); head_ofs += 4;
316                         break;
317                 case 'C':
318                         s = va_arg(ap, char *);
319
320                         if (blob->data + head_ofs < (uint8_t *)head_ofs || blob->data + head_ofs < blob->data)
321                                 return false;   
322         
323                         head_ofs += pull_string(global_smb_iconv_convenience, p, blob->data+head_ofs, sizeof(p), 
324                                                 blob->length - head_ofs, 
325                                                 STR_ASCII|STR_TERMINATE);
326                         if (strcmp(s, p) != 0) {
327                                 return false;
328                         }
329                         break;
330                 }
331         }
332         va_end(ap);
333
334         return true;
335 }