s4:rpc_server: add dcesrv_iface_state_{store,find}_{assoc,conn}() helpers
[samba.git] / source4 / rpc_server / dcerpc_server.c
index 7949c66323ac093eb22b6c4ac86792f027c445f6..a79a569477203eb66753db153f894a1d1b12be11 100644 (file)
@@ -459,25 +459,69 @@ _PUBLIC_ NTSTATUS dcesrv_interface_register(struct dcesrv_context *dce_ctx,
        return NT_STATUS_OK;
 }
 
-NTSTATUS dcesrv_inherited_session_key(struct dcesrv_connection *p,
-                                     DATA_BLOB *session_key)
+static NTSTATUS dcesrv_session_info_session_key(struct dcesrv_auth *auth,
+                                               DATA_BLOB *session_key)
 {
-       if (p->auth_state.session_info->session_key.length) {
-               *session_key = p->auth_state.session_info->session_key;
-               return NT_STATUS_OK;
+       if (auth->session_info == NULL) {
+               return NT_STATUS_NO_USER_SESSION_KEY;
+       }
+
+       if (auth->session_info->session_key.length == 0) {
+               return NT_STATUS_NO_USER_SESSION_KEY;
+       }
+
+       *session_key = auth->session_info->session_key;
+       return NT_STATUS_OK;
+}
+
+static NTSTATUS dcesrv_remote_session_key(struct dcesrv_auth *auth,
+                                         DATA_BLOB *session_key)
+{
+       if (auth->auth_type != DCERPC_AUTH_TYPE_NONE) {
+               return NT_STATUS_NO_USER_SESSION_KEY;
        }
-       return NT_STATUS_NO_USER_SESSION_KEY;
+
+       return dcesrv_session_info_session_key(auth, session_key);
+}
+
+static NTSTATUS dcesrv_local_fixed_session_key(struct dcesrv_auth *auth,
+                                              DATA_BLOB *session_key)
+{
+       return dcerpc_generic_session_key(NULL, session_key);
 }
 
 /*
-  fetch the user session key - may be default (above) or the SMB session key
+ * Fetch the authentication session key if available.
+ *
+ * This is the key generated by a gensec authentication.
+ *
+ */
+_PUBLIC_ NTSTATUS dcesrv_auth_session_key(struct dcesrv_call_state *call,
+                                         DATA_BLOB *session_key)
+{
+       struct dcesrv_auth *auth = call->auth_state;
+
+       return dcesrv_session_info_session_key(auth, session_key);
+}
 
-  The key is always truncated to 16 bytes 
+/*
+ * Fetch the transport session key if available.
+ * Typically this is the SMB session key
+ * or a fixed key for local transports.
+ *
+ * The key is always truncated to 16 bytes.
 */
-_PUBLIC_ NTSTATUS dcesrv_fetch_session_key(struct dcesrv_connection *p,
-                                 DATA_BLOB *session_key)
+_PUBLIC_ NTSTATUS dcesrv_transport_session_key(struct dcesrv_call_state *call,
+                                              DATA_BLOB *session_key)
 {
-       NTSTATUS status = p->auth_state.session_key(p, session_key);
+       struct dcesrv_auth *auth = call->auth_state;
+       NTSTATUS status;
+
+       if (auth->session_key_fn == NULL) {
+               return NT_STATUS_NO_USER_SESSION_KEY;
+       }
+
+       status = auth->session_key_fn(auth, session_key);
        if (!NT_STATUS_IS_OK(status)) {
                return status;
        }
@@ -487,10 +531,41 @@ _PUBLIC_ NTSTATUS dcesrv_fetch_session_key(struct dcesrv_connection *p,
        return NT_STATUS_OK;
 }
 
+static struct dcesrv_auth *dcesrv_auth_create(struct dcesrv_connection *conn)
+{
+       const struct dcesrv_endpoint *ep = conn->endpoint;
+       enum dcerpc_transport_t transport =
+               dcerpc_binding_get_transport(ep->ep_description);
+       struct dcesrv_auth *auth = NULL;
+
+       auth = talloc_zero(conn, struct dcesrv_auth);
+       if (auth == NULL) {
+               return NULL;
+       }
+
+       switch (transport) {
+       case NCACN_NP:
+               auth->session_key_fn = dcesrv_remote_session_key;
+               break;
+       case NCALRPC:
+       case NCACN_UNIX_STREAM:
+               auth->session_key_fn = dcesrv_local_fixed_session_key;
+               break;
+       default:
+               /*
+                * All other's get a NULL pointer, which
+                * results in NT_STATUS_NO_USER_SESSION_KEY
+                */
+               break;
+       }
+
+       return auth;
+}
+
 /*
   connect to a dcerpc endpoint
 */
-_PUBLIC_ NTSTATUS dcesrv_endpoint_connect(struct dcesrv_context *dce_ctx,
+static NTSTATUS dcesrv_endpoint_connect(struct dcesrv_context *dce_ctx,
                                 TALLOC_CTX *mem_ctx,
                                 const struct dcesrv_endpoint *ep,
                                 struct auth_session_info *session_info,
@@ -500,6 +575,7 @@ _PUBLIC_ NTSTATUS dcesrv_endpoint_connect(struct dcesrv_context *dce_ctx,
                                 uint32_t state_flags,
                                 struct dcesrv_connection **_p)
 {
+       struct dcesrv_auth *auth = NULL;
        struct dcesrv_connection *p;
 
        if (!session_info) {
@@ -509,16 +585,9 @@ _PUBLIC_ NTSTATUS dcesrv_endpoint_connect(struct dcesrv_context *dce_ctx,
        p = talloc_zero(mem_ctx, struct dcesrv_connection);
        NT_STATUS_HAVE_NO_MEMORY(p);
 
-       if (!talloc_reference(p, session_info)) {
-               talloc_free(p);
-               return NT_STATUS_NO_MEMORY;
-       }
-
        p->dce_ctx = dce_ctx;
        p->endpoint = ep;
        p->packet_log_dir = lpcfg_lock_directory(dce_ctx->lp_ctx);
-       p->auth_state.session_info = session_info;
-       p->auth_state.session_key = dcesrv_generic_session_key;
        p->event_ctx = event_ctx;
        p->msg_ctx = msg_ctx;
        p->server_id = server_id;
@@ -528,6 +597,20 @@ _PUBLIC_ NTSTATUS dcesrv_endpoint_connect(struct dcesrv_context *dce_ctx,
        p->max_xmit_frag = 5840;
        p->max_total_request_size = DCERPC_NCACN_REQUEST_DEFAULT_MAX_SIZE;
 
+       auth = dcesrv_auth_create(p);
+       if (auth == NULL) {
+               talloc_free(p);
+               return NT_STATUS_NO_MEMORY;
+       }
+
+       auth->session_info = talloc_reference(auth, session_info);
+       if (auth->session_info == NULL) {
+               talloc_free(p);
+               return NT_STATUS_NO_MEMORY;
+       }
+
+       p->default_auth_state = auth;
+
        /*
         * For now we only support NDR32.
         */
@@ -583,8 +666,8 @@ static void dcesrv_call_disconnect_after(struct dcesrv_call_state *call,
 
        call->conn->allow_bind = false;
        call->conn->allow_alter = false;
-       call->conn->allow_auth3 = false;
-       call->conn->allow_request = false;
+
+       call->conn->default_auth_state->auth_invalid = true;
 
        call->terminate_reason = talloc_strdup(call, reason);
        if (call->terminate_reason == NULL) {
@@ -887,7 +970,7 @@ static NTSTATUS dcesrv_bind(struct dcesrv_call_state *call)
        uint16_t max_rep = 0;
        const char *ep_prefix = "";
        const char *endpoint = NULL;
-       struct dcesrv_auth *auth = &call->conn->auth_state;
+       struct dcesrv_auth *auth = call->auth_state;
        struct dcerpc_ack_ctx *ack_ctx_list = NULL;
        struct dcerpc_ack_ctx *ack_features = NULL;
        struct tevent_req *subreq = NULL;
@@ -1024,7 +1107,7 @@ static NTSTATUS dcesrv_bind(struct dcesrv_call_state *call)
                                DCERPC_BIND_TIME_KEEP_CONNECTION_ON_ORPHAN;
                }
 
-               call->conn->bind_time_features = a->reason.negotiate;
+               call->conn->assoc_group->bind_time_features = a->reason.negotiate;
        }
 
        /*
@@ -1209,15 +1292,15 @@ static void dcesrv_auth3_done(struct tevent_req *subreq);
 static NTSTATUS dcesrv_auth3(struct dcesrv_call_state *call)
 {
        struct dcesrv_connection *conn = call->conn;
-       struct dcesrv_auth *auth = &call->conn->auth_state;
+       struct dcesrv_auth *auth = call->auth_state;
        struct tevent_req *subreq = NULL;
        NTSTATUS status;
 
-       if (!call->conn->allow_auth3) {
+       if (!auth->auth_started) {
                return dcesrv_fault_disconnect(call, DCERPC_NCA_S_PROTO_ERROR);
        }
 
-       if (call->conn->auth_state.auth_finished) {
+       if (auth->auth_finished) {
                return dcesrv_fault_disconnect(call, DCERPC_NCA_S_PROTO_ERROR);
        }
 
@@ -1246,7 +1329,7 @@ static NTSTATUS dcesrv_auth3(struct dcesrv_call_state *call)
                 * In anycase we mark the connection as
                 * invalid.
                 */
-               call->conn->auth_state.auth_invalid = true;
+               auth->auth_invalid = true;
                if (call->fault_code != 0) {
                        return dcesrv_fault_disconnect(call, call->fault_code);
                }
@@ -1271,6 +1354,7 @@ static void dcesrv_auth3_done(struct tevent_req *subreq)
                tevent_req_callback_data(subreq,
                struct dcesrv_call_state);
        struct dcesrv_connection *conn = call->conn;
+       struct dcesrv_auth *auth = call->auth_state;
        NTSTATUS status;
 
        status = gensec_update_recv(subreq, call,
@@ -1286,7 +1370,7 @@ static void dcesrv_auth3_done(struct tevent_req *subreq)
                 * In anycase we mark the connection as
                 * invalid.
                 */
-               call->conn->auth_state.auth_invalid = true;
+               auth->auth_invalid = true;
                if (call->fault_code != 0) {
                        status = dcesrv_fault_disconnect(call, call->fault_code);
                        dcesrv_conn_auth_wait_finished(conn, status);
@@ -1546,7 +1630,7 @@ static NTSTATUS dcesrv_alter(struct dcesrv_call_state *call)
        bool auth_ok = false;
        struct ncacn_packet *pkt = &call->ack_pkt;
        uint32_t extra_flags = 0;
-       struct dcesrv_auth *auth = &call->conn->auth_state;
+       struct dcesrv_auth *auth = call->auth_state;
        struct dcerpc_ack_ctx *ack_ctx_list = NULL;
        struct tevent_req *subreq = NULL;
        size_t i;
@@ -1629,9 +1713,7 @@ static NTSTATUS dcesrv_alter(struct dcesrv_call_state *call)
 
        /* handle any authentication that is being requested */
        if (!auth_ok) {
-               if (call->in_auth_info.auth_type !=
-                   call->conn->auth_state.auth_type)
-               {
+               if (call->in_auth_info.auth_type != auth->auth_type) {
                        return dcesrv_fault_disconnect(call,
                                        DCERPC_FAULT_SEC_PKG_ERROR);
                }
@@ -1722,7 +1804,8 @@ static void dcesrv_save_call(struct dcesrv_call_state *call, const char *why)
 static NTSTATUS dcesrv_check_verification_trailer(struct dcesrv_call_state *call)
 {
        TALLOC_CTX *frame = talloc_stackframe();
-       const uint32_t bitmask1 = call->conn->auth_state.client_hdr_signing ?
+       const struct dcesrv_auth *auth = call->auth_state;
+       const uint32_t bitmask1 = auth->client_hdr_signing ?
                DCERPC_SEC_VT_CLIENT_SUPPORTS_HEADER_SIGNING : 0;
        const struct dcerpc_sec_vt_pcontext pcontext = {
                .abstract_syntax = call->context->iface->syntax_id,
@@ -1761,18 +1844,19 @@ done:
 static NTSTATUS dcesrv_request(struct dcesrv_call_state *call)
 {
        const struct dcesrv_endpoint *endpoint = call->conn->endpoint;
+       struct dcesrv_auth *auth = call->auth_state;
        enum dcerpc_transport_t transport =
                dcerpc_binding_get_transport(endpoint->ep_description);
        struct ndr_pull *pull;
        NTSTATUS status;
 
-       if (!call->conn->allow_request) {
+       if (!auth->auth_finished) {
                return dcesrv_fault_disconnect(call, DCERPC_NCA_S_PROTO_ERROR);
        }
 
        /* if authenticated, and the mech we use can't do async replies, don't use them... */
-       if (call->conn->auth_state.gensec_security && 
-           !gensec_have_feature(call->conn->auth_state.gensec_security, GENSEC_FEATURE_ASYNC_REPLIES)) {
+       if (auth->gensec_security != NULL &&
+           !gensec_have_feature(auth->gensec_security, GENSEC_FEATURE_ASYNC_REPLIES)) {
                call->state_flags &= ~DCESRV_CALL_STATE_FLAG_MAY_ASYNC;
        }
 
@@ -1781,7 +1865,7 @@ static NTSTATUS dcesrv_request(struct dcesrv_call_state *call)
                                        DCERPC_PFC_FLAG_DID_NOT_EXECUTE);
        }
 
-       switch (call->conn->auth_state.auth_level) {
+       switch (auth->auth_level) {
        case DCERPC_AUTH_LEVEL_NONE:
        case DCERPC_AUTH_LEVEL_PACKET:
        case DCERPC_AUTH_LEVEL_INTEGRITY:
@@ -1798,8 +1882,8 @@ static NTSTATUS dcesrv_request(struct dcesrv_call_state *call)
                                  "to [%s] with auth[type=0x%x,level=0x%x] "
                                  "on [%s] from [%s]\n",
                                  __func__, call->context->iface->name,
-                                 call->conn->auth_state.auth_type,
-                                 call->conn->auth_state.auth_level,
+                                 auth->auth_type,
+                                 auth->auth_level,
                                  derpc_transport_string_by_transport(transport),
                                  addr));
                        return dcesrv_fault(call, DCERPC_FAULT_ACCESS_DENIED);
@@ -1807,7 +1891,7 @@ static NTSTATUS dcesrv_request(struct dcesrv_call_state *call)
                break;
        }
 
-       if (call->conn->auth_state.auth_level < call->context->min_auth_level) {
+       if (auth->auth_level < call->context->min_auth_level) {
                char *addr;
 
                addr = tsocket_address_string(call->conn->remote_address, call);
@@ -1818,8 +1902,8 @@ static NTSTATUS dcesrv_request(struct dcesrv_call_state *call)
                          __func__,
                          call->context->min_auth_level,
                          call->context->iface->name,
-                         call->conn->auth_state.auth_type,
-                         call->conn->auth_state.auth_level,
+                         auth->auth_type,
+                         auth->auth_level,
                          derpc_transport_string_by_transport(transport),
                          addr));
                return dcesrv_fault(call, DCERPC_FAULT_ACCESS_DENIED);
@@ -1938,6 +2022,8 @@ static NTSTATUS dcesrv_process_ncacn_packet(struct dcesrv_connection *dce_conn,
        talloc_steal(call, blob.data);
        call->pkt = *pkt;
 
+       call->auth_state = dce_conn->default_auth_state;
+
        talloc_set_destructor(call, dcesrv_call_dequeue);
 
        if (call->conn->allow_bind) {
@@ -1951,7 +2037,7 @@ static NTSTATUS dcesrv_process_ncacn_packet(struct dcesrv_connection *dce_conn,
        /* we have to check the signing here, before combining the
           pdus */
        if (call->pkt.ptype == DCERPC_PKT_REQUEST) {
-               if (!call->conn->allow_request) {
+               if (!call->auth_state->auth_finished) {
                        return dcesrv_fault_disconnect(call,
                                        DCERPC_NCA_S_PROTO_ERROR);
                }
@@ -2388,9 +2474,9 @@ static void dcesrv_terminate_connection(struct dcesrv_connection *dce_conn, cons
        dce_conn->wait_private = NULL;
 
        dce_conn->allow_bind = false;
-       dce_conn->allow_auth3 = false;
        dce_conn->allow_alter = false;
-       dce_conn->allow_request = false;
+
+       dce_conn->default_auth_state->auth_invalid = true;
 
        if (dce_conn->pending_call_list == NULL) {
                char *full_reason = talloc_asprintf(dce_conn, "dcesrv: %s", reason);
@@ -2661,7 +2747,6 @@ static void dcesrv_sock_accept(struct stream_connection *srv_conn)
        }
 
        if (transport == NCACN_NP) {
-               dcesrv_conn->auth_state.session_key = dcesrv_inherited_session_key;
                dcesrv_conn->stream = talloc_move(dcesrv_conn,
                                                  &srv_conn->tstream);
        } else {
@@ -3145,7 +3230,8 @@ NTSTATUS dcesrv_add_ep(struct dcesrv_context *dce_ctx,
  */
 _PUBLIC_ struct cli_credentials *dcesrv_call_credentials(struct dcesrv_call_state *dce_call)
 {
-       return dce_call->conn->auth_state.session_info->credentials;
+       struct dcesrv_auth *auth = dce_call->auth_state;
+       return auth->session_info->credentials;
 }
 
 /**
@@ -3153,8 +3239,9 @@ _PUBLIC_ struct cli_credentials *dcesrv_call_credentials(struct dcesrv_call_stat
  */
 _PUBLIC_ bool dcesrv_call_authenticated(struct dcesrv_call_state *dce_call)
 {
+       struct dcesrv_auth *auth = dce_call->auth_state;
        enum security_user_level level;
-       level = security_session_user_level(dce_call->conn->auth_state.session_info, NULL);
+       level = security_session_user_level(auth->session_info, NULL);
        return level >= SECURITY_USER;
 }
 
@@ -3163,5 +3250,32 @@ _PUBLIC_ bool dcesrv_call_authenticated(struct dcesrv_call_state *dce_call)
  */
 _PUBLIC_ const char *dcesrv_call_account_name(struct dcesrv_call_state *dce_call)
 {
-       return dce_call->context->conn->auth_state.session_info->info->account_name;
+       struct dcesrv_auth *auth = dce_call->auth_state;
+       return auth->session_info->info->account_name;
+}
+
+/**
+ * retrieve session_info from a dce_call
+ */
+_PUBLIC_ struct auth_session_info *dcesrv_call_session_info(struct dcesrv_call_state *dce_call)
+{
+       struct dcesrv_auth *auth = dce_call->auth_state;
+       return auth->session_info;
+}
+
+/**
+ * retrieve auth type/level from a dce_call
+ */
+_PUBLIC_ void dcesrv_call_auth_info(struct dcesrv_call_state *dce_call,
+                                   enum dcerpc_AuthType *auth_type,
+                                   enum dcerpc_AuthLevel *auth_level)
+{
+       struct dcesrv_auth *auth = dce_call->auth_state;
+
+       if (auth_type != NULL) {
+               *auth_type = auth->auth_type;
+       }
+       if (auth_level != NULL) {
+               *auth_level = auth->auth_level;
+       }
 }