dns: Simplify logic a bit
[kai/samba-autobuild/.git] / source4 / dns_server / dns_crypto.c
index 14dc4ca69b5c95d8db108237b2114f03eda05661..740e1e4dd533747366ff494e15b179c9065790d5 100644 (file)
@@ -29,6 +29,9 @@
 #include "auth/auth.h"
 #include "auth/gensec/gensec.h"
 
+#undef DBGC_CLASS
+#define DBGC_CLASS DBGC_DNS
+
 static WERROR dns_copy_tsig(TALLOC_CTX *mem_ctx,
                            struct dns_res_rec *old,
                            struct dns_res_rec *new_rec)
@@ -97,7 +100,6 @@ WERROR dns_verify_tsig(struct dns_server *dns,
        WERROR werror;
        NTSTATUS status;
        enum ndr_err_code ndr_err;
-       bool found_tsig = false;
        uint16_t i, arcount = 0;
        DATA_BLOB tsig_blob, fake_tsig_blob, sig;
        uint8_t *buffer = NULL;
@@ -110,27 +112,27 @@ WERROR dns_verify_tsig(struct dns_server *dns,
        /* Find the first TSIG record in the additional records */
        for (i=0; i < packet->arcount; i++) {
                if (packet->additional[i].rr_type == DNS_QTYPE_TSIG) {
-                       found_tsig = true;
                        break;
                }
        }
 
-       if (!found_tsig) {
+       if (i == packet->arcount) {
+               /* no TSIG around */
                return WERR_OK;
        }
 
        /* The TSIG record needs to be the last additional record */
-       if (found_tsig && i + 1 != packet->arcount) {
-               DEBUG(0, ("TSIG record not the last additional record!\n"));
+       if (i + 1 != packet->arcount) {
+               DEBUG(1, ("TSIG record not the last additional record!\n"));
                return DNS_ERR(FORMAT_ERROR);
        }
 
        /* We got a TSIG, so we need to sign our reply */
        state->sign = true;
 
-       state->tsig = talloc_zero(mem_ctx, struct dns_res_rec);
+       state->tsig = talloc_zero(state->mem_ctx, struct dns_res_rec);
        if (state->tsig == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
 
        werror = dns_copy_tsig(state->tsig, &packet->additional[i],
@@ -143,25 +145,45 @@ WERROR dns_verify_tsig(struct dns_server *dns,
 
        tkey = dns_find_tkey(dns->tkeys, state->tsig->name);
        if (tkey == NULL) {
+               /*
+                * We must save the name for use in the TSIG error
+                * response and have no choice here but to save the
+                * keyname from the TSIG request.
+                */
+               state->key_name = talloc_strdup(state->mem_ctx,
+                                               state->tsig->name);
+               if (state->key_name == NULL) {
+                       return WERR_NOT_ENOUGH_MEMORY;
+               }
                state->tsig_error = DNS_RCODE_BADKEY;
                return DNS_ERR(NOTAUTH);
        }
 
+       /*
+        * Remember the keyname that found an existing tkey, used
+        * later to fetch the key with dns_find_tkey() when signing
+        * and adding a TSIG record with MAC.
+        */
+       state->key_name = talloc_strdup(state->mem_ctx, tkey->name);
+       if (state->key_name == NULL) {
+               return WERR_NOT_ENOUGH_MEMORY;
+       }
+
        /* FIXME: check TSIG here */
        if (check_rec == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
 
        /* first build and verify check packet */
        check_rec->name = talloc_strdup(check_rec, tkey->name);
        if (check_rec->name == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
        check_rec->rr_class = DNS_QCLASS_ANY;
        check_rec->ttl = 0;
        check_rec->algorithm_name = talloc_strdup(check_rec, tkey->algorithm);
        if (check_rec->algorithm_name == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
        check_rec->time_prefix = 0;
        check_rec->time = state->tsig->rdata.tsig_record.time;
@@ -192,7 +214,7 @@ WERROR dns_verify_tsig(struct dns_server *dns,
        buffer_len = packet_len + fake_tsig_blob.length;
        buffer = talloc_zero_array(mem_ctx, uint8_t, buffer_len);
        if (buffer == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
 
        memcpy(buffer, in->data, packet_len);
@@ -201,12 +223,9 @@ WERROR dns_verify_tsig(struct dns_server *dns,
        sig.length = state->tsig->rdata.tsig_record.mac_size;
        sig.data = talloc_memdup(mem_ctx, state->tsig->rdata.tsig_record.mac, sig.length);
        if (sig.data == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
 
-       /*FIXME: Why is there too much padding? */
-       buffer_len -= 2;
-
        /* Now we also need to count down the additional record counter */
        arcount = RSVAL(buffer, 10);
        RSSVAL(buffer, 10, arcount-1);
@@ -214,11 +233,12 @@ WERROR dns_verify_tsig(struct dns_server *dns,
        status = gensec_check_packet(tkey->gensec, buffer, buffer_len,
                                    buffer, buffer_len, &sig);
        if (NT_STATUS_EQUAL(NT_STATUS_ACCESS_DENIED, status)) {
-               return DNS_ERR(BADKEY);
+               state->tsig_error = DNS_RCODE_BADSIG;
+               return DNS_ERR(NOTAUTH);
        }
 
        if (!NT_STATUS_IS_OK(status)) {
-               DEBUG(0, ("Verifying tsig failed: %s\n", nt_errstr(status)));
+               DEBUG(1, ("Verifying tsig failed: %s\n", nt_errstr(status)));
                return ntstatus_to_werror(status);
        }
 
@@ -227,49 +247,37 @@ WERROR dns_verify_tsig(struct dns_server *dns,
        return WERR_OK;
 }
 
-WERROR dns_sign_tsig(struct dns_server *dns,
-                    TALLOC_CTX *mem_ctx,
-                    struct dns_request_state *state,
-                    struct dns_name_packet *packet,
-                    uint16_t error)
+static WERROR dns_tsig_compute_mac(TALLOC_CTX *mem_ctx,
+                                  struct dns_request_state *state,
+                                  struct dns_name_packet *packet,
+                                  struct dns_server_tkey *tkey,
+                                  time_t current_time,
+                                  DATA_BLOB *_psig)
 {
-       WERROR werror;
        NTSTATUS status;
        enum ndr_err_code ndr_err;
-       time_t current_time = time(NULL);
        DATA_BLOB packet_blob, tsig_blob, sig;
        uint8_t *buffer = NULL;
+       uint8_t *p = NULL;
        size_t buffer_len = 0;
-       struct dns_server_tkey * tkey = NULL;
-       struct dns_res_rec *tsig = talloc_zero(mem_ctx, struct dns_res_rec);
-
        struct dns_fake_tsig_rec *check_rec = talloc_zero(mem_ctx,
                        struct dns_fake_tsig_rec);
-
-       if (tsig == NULL) {
-               return WERR_NOMEM;
-       }
+       size_t mac_size = 0;
 
        if (check_rec == NULL) {
-               return WERR_NOMEM;
-       }
-
-       tkey = dns_find_tkey(dns->tkeys, state->key_name);
-       if (tkey == NULL) {
-               /* FIXME: read up on what to do when we can't find a key */
-               return WERR_OK;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
 
        /* first build and verify check packet */
        check_rec->name = talloc_strdup(check_rec, tkey->name);
        if (check_rec->name == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
        check_rec->rr_class = DNS_QCLASS_ANY;
        check_rec->ttl = 0;
        check_rec->algorithm_name = talloc_strdup(check_rec, tkey->algorithm);
        if (check_rec->algorithm_name == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
        check_rec->time_prefix = 0;
        check_rec->time = current_time;
@@ -294,15 +302,44 @@ WERROR dns_sign_tsig(struct dns_server *dns,
                return DNS_ERR(SERVER_FAILURE);
        }
 
-       buffer_len = packet_blob.length + tsig_blob.length;
+       if (state->tsig != NULL) {
+               mac_size = state->tsig->rdata.tsig_record.mac_size;
+       }
+
+       buffer_len = mac_size;
+
+       buffer_len += packet_blob.length;
+       if (buffer_len < packet_blob.length) {
+               return WERR_INVALID_PARAMETER;
+       }
+       buffer_len += tsig_blob.length;
+       if (buffer_len < tsig_blob.length) {
+               return WERR_INVALID_PARAMETER;
+       }
+
        buffer = talloc_zero_array(mem_ctx, uint8_t, buffer_len);
        if (buffer == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
+       }
+
+       p = buffer;
+
+       /*
+        * RFC 2845 "4.2 TSIG on Answers", how to lay out the buffer
+        * that we're going to sign:
+        * 1. MAC of request (if present)
+        * 2. Outgoing packet
+        * 3. TSIG record
+        */
+       if (mac_size > 0) {
+               memcpy(p, state->tsig->rdata.tsig_record.mac, mac_size);
+               p += mac_size;
        }
 
-       memcpy(buffer, packet_blob.data, packet_blob.length);
-       memcpy(buffer+packet_blob.length, tsig_blob.data, tsig_blob.length);
+       memcpy(p, packet_blob.data, packet_blob.length);
+       p += packet_blob.length;
 
+       memcpy(p, tsig_blob.data, tsig_blob.length);
 
        status = gensec_sign_packet(tkey->gensec, mem_ctx, buffer, buffer_len,
                                    buffer, buffer_len, &sig);
@@ -310,38 +347,75 @@ WERROR dns_sign_tsig(struct dns_server *dns,
                return ntstatus_to_werror(status);
        }
 
-       tsig->name = talloc_strdup(tsig, check_rec->name);
+       *_psig = sig;
+       return WERR_OK;
+}
+
+WERROR dns_sign_tsig(struct dns_server *dns,
+                    TALLOC_CTX *mem_ctx,
+                    struct dns_request_state *state,
+                    struct dns_name_packet *packet,
+                    uint16_t error)
+{
+       WERROR werror;
+       time_t current_time = time(NULL);
+       struct dns_res_rec *tsig = NULL;
+       DATA_BLOB sig = (DATA_BLOB) {
+               .data = NULL,
+               .length = 0
+       };
+
+       tsig = talloc_zero(mem_ctx, struct dns_res_rec);
+       if (tsig == NULL) {
+               return WERR_NOT_ENOUGH_MEMORY;
+       }
+
+       if (state->tsig_error == DNS_RCODE_OK) {
+               struct dns_server_tkey *tkey = dns_find_tkey(
+                       dns->tkeys, state->key_name);
+               if (tkey == NULL) {
+                       return DNS_ERR(SERVER_FAILURE);
+               }
+
+               werror = dns_tsig_compute_mac(mem_ctx, state, packet,
+                                             tkey, current_time, &sig);
+               if (!W_ERROR_IS_OK(werror)) {
+                       return werror;
+               }
+       }
+
+       tsig->name = talloc_strdup(tsig, state->key_name);
        if (tsig->name == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
-       tsig->rr_class = check_rec->rr_class;
+       tsig->rr_class = DNS_QCLASS_ANY;
        tsig->rr_type = DNS_QTYPE_TSIG;
        tsig->ttl = 0;
        tsig->length = UINT16_MAX;
-       tsig->rdata.tsig_record.algorithm_name = talloc_strdup(tsig,
-                       check_rec->algorithm_name);
-       tsig->rdata.tsig_record.time_prefix = check_rec->time_prefix;
-       tsig->rdata.tsig_record.time = check_rec->time;
-       tsig->rdata.tsig_record.fudge = check_rec->fudge;
+       tsig->rdata.tsig_record.algorithm_name = talloc_strdup(tsig, "gss-tsig");
+       tsig->rdata.tsig_record.time_prefix = 0;
+       tsig->rdata.tsig_record.time = current_time;
+       tsig->rdata.tsig_record.fudge = 300;
        tsig->rdata.tsig_record.error = state->tsig_error;
        tsig->rdata.tsig_record.original_id = packet->id;
        tsig->rdata.tsig_record.other_size = 0;
        tsig->rdata.tsig_record.other_data = NULL;
-       tsig->rdata.tsig_record.mac_size = sig.length;
-       tsig->rdata.tsig_record.mac = talloc_memdup(tsig, sig.data, sig.length);
-
+       if (sig.length > 0) {
+               tsig->rdata.tsig_record.mac_size = sig.length;
+               tsig->rdata.tsig_record.mac = talloc_memdup(tsig, sig.data, sig.length);
+       }
 
        if (packet->arcount == 0) {
                packet->additional = talloc_zero(mem_ctx, struct dns_res_rec);
                if (packet->additional == NULL) {
-                       return WERR_NOMEM;
+                       return WERR_NOT_ENOUGH_MEMORY;
                }
        }
        packet->additional = talloc_realloc(mem_ctx, packet->additional,
                                            struct dns_res_rec,
                                            packet->arcount + 1);
        if (packet->additional == NULL) {
-               return WERR_NOMEM;
+               return WERR_NOT_ENOUGH_MEMORY;
        }
 
        werror = dns_copy_tsig(mem_ctx, tsig,