a713b9896d47e323b80724f586878fb9e70e5099
[jelmer/samba4-debian.git] / source / libcli / auth / ntlmssp_parse.c
1 /* 
2    Unix SMB/CIFS implementation.
3    simple kerberos5/SPNEGO routines
4    Copyright (C) Andrew Tridgell 2001
5    Copyright (C) Jim McDonough <jmcd@us.ibm.com> 2002
6    Copyright (C) Andrew Bartlett 2002-2003
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 #include "includes.h"
24
25 /*
26   this is a tiny msrpc packet generator. I am only using this to
27   avoid tying this code to a particular varient of our rpc code. This
28   generator is not general enough for all our rpc needs, its just
29   enough for the spnego/ntlmssp code
30
31   format specifiers are:
32
33   U = unicode string (input is unix string)
34   a = address (input is char *unix_string)
35       (1 byte type, 1 byte length, unicode/ASCII string, all inline)
36   A = ASCII string (input is unix string)
37   B = data blob (pointer + length)
38   b = data blob in header (pointer + length)
39   D
40   d = word (4 bytes)
41   C = constant ascii string
42  */
43 BOOL msrpc_gen(TALLOC_CTX *mem_ctx, DATA_BLOB *blob,
44                const char *format, ...)
45 {
46         int i;
47         ssize_t n;
48         va_list ap;
49         char *s;
50         uint8_t *b;
51         int head_size=0, data_size=0;
52         int head_ofs, data_ofs;
53         int *intargs;
54
55         DATA_BLOB *pointers;
56
57         pointers = talloc_array_p(mem_ctx, DATA_BLOB, strlen(format));
58         intargs = talloc_array_p(pointers, int, strlen(format));
59
60         /* first scan the format to work out the header and body size */
61         va_start(ap, format);
62         for (i=0; format[i]; i++) {
63                 switch (format[i]) {
64                 case 'U':
65                         s = va_arg(ap, char *);
66                         head_size += 8;
67                         n = push_ucs2_talloc(pointers, (void **)&pointers[i].data, s);
68                         if (n == -1) {
69                                 return False;
70                         }
71                         pointers[i].length = n;
72                         pointers[i].length -= 2;
73                         data_size += pointers[i].length;
74                         break;
75                 case 'A':
76                         s = va_arg(ap, char *);
77                         head_size += 8;
78                         n = push_ascii_talloc(pointers, (char **)&pointers[i].data, s);
79                         if (n == -1) {
80                                 return False;
81                         }
82                         pointers[i].length = n;
83                         pointers[i].length -= 1;
84                         data_size += pointers[i].length;
85                         break;
86                 case 'a':
87                         n = va_arg(ap, int);
88                         intargs[i] = n;
89                         s = va_arg(ap, char *);
90                         n = push_ucs2_talloc(pointers, (void **)&pointers[i].data, s);
91                         if (n == -1) {
92                                 return False;
93                         }
94                         pointers[i].length = n;
95                         pointers[i].length -= 2;
96                         data_size += pointers[i].length + 4;
97                         break;
98                 case 'B':
99                         b = va_arg(ap, uint8_t *);
100                         head_size += 8;
101                         pointers[i].data = b;
102                         pointers[i].length = va_arg(ap, int);
103                         data_size += pointers[i].length;
104                         break;
105                 case 'b':
106                         b = va_arg(ap, uint8_t *);
107                         pointers[i].data = b;
108                         pointers[i].length = va_arg(ap, int);
109                         head_size += pointers[i].length;
110                         break;
111                 case 'd':
112                         n = va_arg(ap, int);
113                         intargs[i] = n;
114                         head_size += 4;
115                         break;
116                 case 'C':
117                         s = va_arg(ap, char *);
118                         pointers[i].data = (uint8_t *)s;
119                         pointers[i].length = strlen(s)+1;
120                         head_size += pointers[i].length;
121                         break;
122                 }
123         }
124         va_end(ap);
125
126         /* allocate the space, then scan the format again to fill in the values */
127         *blob = data_blob_talloc(mem_ctx, NULL, head_size + data_size);
128
129         head_ofs = 0;
130         data_ofs = head_size;
131
132         va_start(ap, format);
133         for (i=0; format[i]; i++) {
134                 switch (format[i]) {
135                 case 'U':
136                 case 'A':
137                 case 'B':
138                         n = pointers[i].length;
139                         SSVAL(blob->data, head_ofs, n); head_ofs += 2;
140                         SSVAL(blob->data, head_ofs, n); head_ofs += 2;
141                         SIVAL(blob->data, head_ofs, data_ofs); head_ofs += 4;
142                         if (pointers[i].data && n) /* don't follow null pointers... */
143                                 memcpy(blob->data+data_ofs, pointers[i].data, n);
144                         data_ofs += n;
145                         break;
146                 case 'a':
147                         n = intargs[i];
148                         SSVAL(blob->data, data_ofs, n); data_ofs += 2;
149
150                         n = pointers[i].length;
151                         SSVAL(blob->data, data_ofs, n); data_ofs += 2;
152                         if (n >= 0) {
153                                 memcpy(blob->data+data_ofs, pointers[i].data, n);
154                         }
155                         data_ofs += n;
156                         break;
157                 case 'd':
158                         n = intargs[i];
159                         SIVAL(blob->data, head_ofs, n); 
160                         head_ofs += 4;
161                         break;
162                 case 'b':
163                         n = pointers[i].length;
164                         memcpy(blob->data + head_ofs, pointers[i].data, n);
165                         head_ofs += n;
166                         break;
167                 case 'C':
168                         n = pointers[i].length;
169                         memcpy(blob->data + head_ofs, pointers[i].data, n);
170                         head_ofs += n;
171                         break;
172                 }
173         }
174         va_end(ap);
175         
176         talloc_free(pointers);
177
178         return True;
179 }
180
181
182 /* a helpful macro to avoid running over the end of our blob */
183 #define NEED_DATA(amount) \
184 if ((head_ofs + amount) > blob->length) { \
185         return False; \
186 }
187
188 /*
189   this is a tiny msrpc packet parser. This the the partner of msrpc_gen
190
191   format specifiers are:
192
193   U = unicode string (output is unix string)
194   A = ascii string
195   B = data blob
196   b = data blob in header
197   d = word (4 bytes)
198   C = constant ascii string
199  */
200
201 BOOL msrpc_parse(TALLOC_CTX *mem_ctx, const DATA_BLOB *blob,
202                  const char *format, ...)
203 {
204         int i;
205         va_list ap;
206         const char **ps, *s;
207         DATA_BLOB *b;
208         size_t head_ofs = 0;
209         uint16_t len1, len2;
210         uint32_t ptr;
211         uint32_t *v;
212         pstring p;
213
214         va_start(ap, format);
215         for (i=0; format[i]; i++) {
216                 switch (format[i]) {
217                 case 'U':
218                         NEED_DATA(8);
219                         len1 = SVAL(blob->data, head_ofs); head_ofs += 2;
220                         len2 = SVAL(blob->data, head_ofs); head_ofs += 2;
221                         ptr =  IVAL(blob->data, head_ofs); head_ofs += 4;
222
223                         ps = (const char **)va_arg(ap, char **);
224                         if (len1 == 0 && len2 == 0) {
225                                 *ps = "";
226                         } else {
227                                 /* make sure its in the right format - be strict */
228                                 if ((len1 != len2) || (ptr + len1 < ptr) || (ptr + len1 < len1) || (ptr + len1 > blob->length)) {
229                                         return False;
230                                 }
231                                 if (len1 & 1) {
232                                         /* if odd length and unicode */
233                                         return False;
234                                 }
235                                 if (blob->data + ptr < (uint8_t *)ptr || blob->data + ptr < blob->data)
236                                         return False;
237
238                                 if (0 < len1) {
239                                         pull_string(p, blob->data + ptr, sizeof(p), 
240                                                     len1, 
241                                                     STR_UNICODE|STR_NOALIGN);
242                                         (*ps) = talloc_strdup(mem_ctx, p);
243                                         if (!(*ps)) {
244                                                 return False;
245                                         }
246                                 } else {
247                                         (*ps) = "";
248                                 }
249                         }
250                         break;
251                 case 'A':
252                         NEED_DATA(8);
253                         len1 = SVAL(blob->data, head_ofs); head_ofs += 2;
254                         len2 = SVAL(blob->data, head_ofs); head_ofs += 2;
255                         ptr =  IVAL(blob->data, head_ofs); head_ofs += 4;
256
257                         ps = (const char **)va_arg(ap, char **);
258                         /* make sure its in the right format - be strict */
259                         if (len1 == 0 && len2 == 0) {
260                                 *ps = "";
261                         } else {
262                                 if ((len1 != len2) || (ptr + len1 < ptr) || (ptr + len1 < len1) || (ptr + len1 > blob->length)) {
263                                         return False;
264                                 }
265
266                                 if (blob->data + ptr < (uint8_t *)ptr || blob->data + ptr < blob->data)
267                                         return False;   
268
269                                 if (0 < len1) {
270                                         pull_string(p, blob->data + ptr, sizeof(p), 
271                                                     len1, 
272                                                     STR_ASCII|STR_NOALIGN);
273                                         (*ps) = talloc_strdup(mem_ctx, p);
274                                         if (!(*ps)) {
275                                                 return False;
276                                         }
277                                 } else {
278                                         (*ps) = "";
279                                 }
280                         }
281                         break;
282                 case 'B':
283                         NEED_DATA(8);
284                         len1 = SVAL(blob->data, head_ofs); head_ofs += 2;
285                         len2 = SVAL(blob->data, head_ofs); head_ofs += 2;
286                         ptr =  IVAL(blob->data, head_ofs); head_ofs += 4;
287
288                         b = (DATA_BLOB *)va_arg(ap, void *);
289                         if (len1 == 0 && len2 == 0) {
290                                 *b = data_blob_talloc(mem_ctx, NULL, 0);
291                         } else {
292                                 /* make sure its in the right format - be strict */
293                                 if ((len1 != len2) || (ptr + len1 < ptr) || (ptr + len1 < len1) || (ptr + len1 > blob->length)) {
294                                         return False;
295                                 }
296
297                                 if (blob->data + ptr < (uint8_t *)ptr || blob->data + ptr < blob->data)
298                                         return False;   
299                         
300                                 *b = data_blob_talloc(mem_ctx, blob->data + ptr, len1);
301                         }
302                         break;
303                 case 'b':
304                         b = (DATA_BLOB *)va_arg(ap, void *);
305                         len1 = va_arg(ap, uint_t);
306                         /* make sure its in the right format - be strict */
307                         NEED_DATA(len1);
308                         if (blob->data + head_ofs < (uint8_t *)head_ofs || blob->data + head_ofs < blob->data)
309                                 return False;   
310                         
311                         *b = data_blob_talloc(mem_ctx, blob->data + head_ofs, len1);
312                         head_ofs += len1;
313                         break;
314                 case 'd':
315                         v = va_arg(ap, uint32_t *);
316                         NEED_DATA(4);
317                         *v = IVAL(blob->data, head_ofs); head_ofs += 4;
318                         break;
319                 case 'C':
320                         s = va_arg(ap, char *);
321
322                         if (blob->data + head_ofs < (uint8_t *)head_ofs || blob->data + head_ofs < blob->data)
323                                 return False;   
324         
325                         head_ofs += pull_string(p, blob->data+head_ofs, sizeof(p), 
326                                                 blob->length - head_ofs, 
327                                                 STR_ASCII|STR_TERMINATE);
328                         if (strcmp(s, p) != 0) {
329                                 return False;
330                         }
331                         break;
332                 }
333         }
334         va_end(ap);
335
336         return True;
337 }