lib: Split out write_data[_iov]
[obnox/samba/samba-obnox.git] / source3 / utils / smbfilter.c
index 2f78140897b9cd52a6c8bfb63bb7207310304023..ff966a8c592fc54842206352ad28feb5d1e39db7 100644 (file)
@@ -2,23 +2,27 @@
    Unix SMB/CIFS implementation.
    SMB filter/socket plugin
    Copyright (C) Andrew Tridgell 1999
-   
+
    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.
-   
+
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
-   
+
    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
 
 #include "includes.h"
+#include "system/filesys.h"
+#include "system/select.h"
 #include "../lib/util/select.h"
+#include "libsmb/nmblib.h"
+#include "lib/sys_rw_data.h"
 
 #define SECURITY_MASK 0
 #define SECURITY_SET  0
@@ -32,7 +36,6 @@
 #define CLI_CAPABILITY_SET  0
 
 static char *netbiosname;
-static char packet[BUFFER_SIZE];
 
 static void save_file(const char *fname, void *ppacket, size_t length)
 {
@@ -44,6 +47,7 @@ static void save_file(const char *fname, void *ppacket, size_t length)
        }
        if (write(fd, ppacket, length) != length) {
                fprintf(stderr,"Failed to write %s\n", fname);
+               close(fd);
                return;
        }
        close(fd);
@@ -174,33 +178,54 @@ static void filter_child(int c, struct sockaddr_storage *dest_ss)
 {
        NTSTATUS status;
        int s = -1;
+       char packet[128*1024];
 
        /* we have a connection from a new client, now connect to the server */
-       status = open_socket_out(dest_ss, 445, LONG_CONNECT_TIMEOUT, &s);
-
-       if (s == -1) {
+       status = open_socket_out(dest_ss, TCP_SMB_PORT, LONG_CONNECT_TIMEOUT, &s);
+       if (!NT_STATUS_IS_OK(status)) {
                char addr[INET6_ADDRSTRLEN];
                if (dest_ss) {
                        print_sockaddr(addr, sizeof(addr), dest_ss);
                }
 
                d_printf("Unable to connect to %s (%s)\n",
-                        dest_ss?addr:"NULL",strerror(errno));
+                        dest_ss?addr:"NULL", nt_errstr(status));
                exit(1);
        }
 
        while (c != -1 || s != -1) {
-               fd_set fds;
-               int num;
-               
-               FD_ZERO(&fds);
-               if (s != -1) FD_SET(s, &fds);
-               if (c != -1) FD_SET(c, &fds);
-
-               num = sys_select_intr(MAX(s+1, c+1),&fds,NULL,NULL,NULL);
-               if (num <= 0) continue;
-               
-               if (c != -1 && FD_ISSET(c, &fds)) {
+               struct pollfd fds[2];
+               int num_fds, ret;
+
+               memset(fds, 0, sizeof(struct pollfd) * 2);
+               fds[0].fd = -1;
+               fds[1].fd = -1;
+               num_fds = 0;
+
+               if (s != -1) {
+                       fds[num_fds].fd = s;
+                       fds[num_fds].events = POLLIN|POLLHUP;
+                       num_fds += 1;
+               }
+               if (c != -1) {
+                       fds[num_fds].fd = c;
+                       fds[num_fds].events = POLLIN|POLLHUP;
+                       num_fds += 1;
+               }
+
+               ret = sys_poll_intr(fds, num_fds, -1);
+               if (ret <= 0) {
+                       continue;
+               }
+
+               /*
+                * find c in fds and see if it's readable
+                */
+               if ((c != -1) &&
+                   (((fds[0].fd == c)
+                     && (fds[0].revents & (POLLIN|POLLHUP|POLLERR))) ||
+                    ((fds[1].fd == c)
+                     && (fds[1].revents & (POLLIN|POLLHUP|POLLERR))))) {
                        size_t len;
                        if (!NT_STATUS_IS_OK(receive_smb_raw(
                                                        c, packet, sizeof(packet),
@@ -214,7 +239,15 @@ static void filter_child(int c, struct sockaddr_storage *dest_ss)
                                exit(1);
                        }                       
                }
-               if (s != -1 && FD_ISSET(s, &fds)) {
+
+               /*
+                * find s in fds and see if it's readable
+                */
+               if ((s != -1) &&
+                   (((fds[0].fd == s)
+                     && (fds[0].revents & (POLLIN|POLLHUP|POLLERR))) ||
+                    ((fds[1].fd == s)
+                     && (fds[1].revents & (POLLIN|POLLHUP|POLLERR))))) {
                        size_t len;
                        if (!NT_STATUS_IS_OK(receive_smb_raw(
                                                        s, packet, sizeof(packet),
@@ -245,8 +278,8 @@ static void start_filter(char *desthost)
        /* start listening on port 445 locally */
 
        zero_sockaddr(&my_ss);
-       s = open_socket_in(SOCK_STREAM, 445, 0, &my_ss, True);
-       
+       s = open_socket_in(SOCK_STREAM, TCP_SMB_PORT, 0, &my_ss, True);
+
        if (s == -1) {
                d_printf("bind failed\n");
                exit(1);
@@ -262,16 +295,12 @@ static void start_filter(char *desthost)
        }
 
        while (1) {
-               fd_set fds;
-               int num;
+               int num, revents;
                struct sockaddr_storage ss;
                socklen_t in_addrlen = sizeof(ss);
-               
-               FD_ZERO(&fds);
-               FD_SET(s, &fds);
 
-               num = sys_select_intr(s+1,&fds,NULL,NULL,NULL);
-               if (num > 0) {
+               num = poll_intr_one_fd(s, POLLIN|POLLHUP, -1, &revents);
+               if ((num > 0) && (revents & (POLLIN|POLLHUP|POLLERR))) {
                        c = accept(s, (struct sockaddr *)&ss, &in_addrlen);
                        if (c != -1) {
                                if (fork() == 0) {
@@ -309,7 +338,7 @@ int main(int argc, char *argv[])
                netbiosname = argv[2];
        }
 
-       if (!lp_load(configfile,True,False,False,True)) {
+       if (!lp_load_global(configfile)) {
                d_printf("Unable to load config file\n");
        }