dcerpc server output now copes with the client blocking part way
[samba.git] / source4 / rpc_server / dcerpc_server.c
index a4f5fb9768ec208172fb3cd9989518759f41b33e..0553537cb539f12467ca1139d7b9b8a3f8ba1090 100644 (file)
@@ -885,14 +885,23 @@ NTSTATUS dcesrv_input(struct dcesrv_connection *dce_conn, const DATA_BLOB *data)
 }
 
 /*
-  retrieve some output from a dcerpc server. The amount of data that
-  is wanted is in data->length and data->data is already allocated
-  to hold that much data.
+  retrieve some output from a dcerpc server
+  The caller supplies a function that will be called to do the
+  actual output. 
+
+  The first argument to write_fn() will be 'private', the second will
+  be a pointer to a buffer containing the data to be sent and the 3rd
+  will be the number of bytes to be sent.
+
+  write_fn() should return the number of bytes successfully written.
 */
-NTSTATUS dcesrv_output(struct dcesrv_connection *dce_conn, DATA_BLOB *data)
+NTSTATUS dcesrv_output(struct dcesrv_connection *dce_conn, 
+                      void *private,
+                      ssize_t (*write_fn)(void *, const void *, size_t))
 {
        struct dcesrv_call_state *call;
        struct dcesrv_call_reply *rep;
+       ssize_t nwritten;
 
        call = dce_conn->call_list;
        if (!call || !call->replies) {
@@ -900,13 +909,15 @@ NTSTATUS dcesrv_output(struct dcesrv_connection *dce_conn, DATA_BLOB *data)
        }
        rep = call->replies;
 
-       if (data->length >= rep->data.length) {
-               data->length = rep->data.length;
+       nwritten = write_fn(private, rep->data.data, rep->data.length);
+       if (nwritten == -1) {
+               /* TODO: hmm, how do we cope with this? destroy the
+                  connection perhaps? */
+               return NT_STATUS_UNSUCCESSFUL;
        }
 
-       memcpy(data->data, rep->data.data, data->length);
-       rep->data.length -= data->length;
-       rep->data.data += data->length;
+       rep->data.length -= nwritten;
+       rep->data.data += nwritten;
 
        if (rep->data.length == 0) {
                /* we're done with this section of the call */
@@ -922,6 +933,30 @@ NTSTATUS dcesrv_output(struct dcesrv_connection *dce_conn, DATA_BLOB *data)
        return NT_STATUS_OK;
 }
 
+
+/*
+  write_fn() for dcesrv_output_blob()
+*/
+static ssize_t dcesrv_output_blob_write_fn(void *private, const void *buf, size_t count)
+{
+       DATA_BLOB *blob = private;
+       if (count < blob->length) {
+               blob->length = count;
+       }
+       memcpy(blob->data, buf, blob->length);
+       return blob->length;
+}
+
+/*
+  a simple wrapper for dcesrv_output() for when we want to output
+  into a data blob
+*/
+NTSTATUS dcesrv_output_blob(struct dcesrv_connection *dce_conn, 
+                           DATA_BLOB *blob)
+{
+       return dcesrv_output(dce_conn, blob, dcesrv_output_blob_write_fn);
+}
+
 /*
   initialise the dcerpc server context
 */