lib/util: add tfork()
[amitay/samba.git] / lib / util / tfork.c
1 /*
2    fork on steroids to avoid SIGCHLD and waitpid
3
4    Copyright (C) Stefan Metzmacher 2010
5    Copyright (C) Ralph Boehme 2017
6
7    This program is free software; you can redistribute it and/or modify
8    it under the terms of the GNU General Public License as published by
9    the Free Software Foundation; either version 3 of the License, or
10    (at your option) any later version.
11
12    This program is distributed in the hope that it will be useful,
13    but WITHOUT ANY WARRANTY; without even the implied warranty of
14    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15    GNU General Public License for more details.
16
17    You should have received a copy of the GNU General Public License
18    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 */
20
21 #include "replace.h"
22 #include "system/wait.h"
23 #include "system/filesys.h"
24 #include "lib/util/samba_util.h"
25 #include "lib/util/sys_rw.h"
26 #include "lib/util/tfork.h"
27 #include "lib/util/debug.h"
28
29 struct tfork_state {
30         void (*old_sig_chld)(int);
31         int status_pipe[2];
32         pid_t *parent;
33
34         pid_t level0_pid;
35         int level0_status;
36
37         pid_t level1_pid;
38         int level1_errno;
39
40         pid_t level2_pid;
41         int level2_errno;
42
43         pid_t level3_pid;
44 };
45
46 /*
47  * TODO: We should make this global thread local
48  */
49 static struct tfork_state *tfork_global;
50
51 static void tfork_sig_chld(int signum)
52 {
53         if (tfork_global->level1_pid > 0) {
54                 int ret = waitpid(tfork_global->level1_pid,
55                               &tfork_global->level0_status,
56                               WNOHANG);
57                 if (ret == tfork_global->level1_pid) {
58                         tfork_global->level1_pid = -1;
59                         return;
60                 }
61         }
62
63         /*
64          * Not our child, forward to old handler
65          */
66
67         if (tfork_global->old_sig_chld == SIG_IGN) {
68                 return;
69         }
70
71         if (tfork_global->old_sig_chld == SIG_DFL) {
72                 return;
73         }
74
75         tfork_global->old_sig_chld(signum);
76 }
77
78 static pid_t level2_fork_and_wait(int child_ready_fd)
79 {
80         int status;
81         ssize_t written;
82         pid_t pid;
83         int fd;
84         bool wait;
85
86         /*
87          * Child level 2.
88          *
89          * Do a final fork and if the tfork() caller passed a status_fd, wait
90          * for child3 and return its exit status via status_fd.
91          */
92
93         pid = fork();
94         if (pid == 0) {
95                 /*
96                  * Child level 3, this one finally returns from tfork() as child
97                  * with pid 0.
98                  *
99                  * Cleanup all ressources we allocated before returning.
100                  */
101                 close(child_ready_fd);
102                 close(tfork_global->status_pipe[1]);
103
104                 if (tfork_global->parent != NULL) {
105                         /*
106                          * we're in the child and return the level0 parent pid
107                          */
108                         *tfork_global->parent = tfork_global->level0_pid;
109                 }
110
111                 anonymous_shared_free(tfork_global);
112                 tfork_global = NULL;
113
114                 return 0;
115         }
116
117         tfork_global->level3_pid = pid;
118         if (tfork_global->level3_pid == -1) {
119                 tfork_global->level2_errno = errno;
120                 _exit(0);
121         }
122
123         sys_write(child_ready_fd, &(char){0}, 1);
124
125         if (tfork_global->status_pipe[1] == -1) {
126                 _exit(0);
127         }
128         wait = true;
129
130         /*
131          * We're going to stay around until child3 exits, so lets close all fds
132          * other then the pipe fd we may have inherited from the caller.
133          */
134         fd = dup2(tfork_global->status_pipe[1], 0);
135         if (fd == -1) {
136                 status = errno;
137                 kill(tfork_global->level3_pid, SIGKILL);
138                 wait = false;
139         }
140         closefrom(1);
141
142         while (wait) {
143                 int ret = waitpid(tfork_global->level3_pid, &status, 0);
144                 if (ret == -1) {
145                         if (errno == EINTR) {
146                                 continue;
147                         }
148                         status = errno;
149                 }
150                 break;
151         }
152
153         written = sys_write(fd, &status, sizeof(status));
154         if (written != sizeof(status)) {
155                 abort();
156         }
157
158         _exit(0);
159 }
160
161 pid_t tfork(int *status_fd, pid_t *parent)
162 {
163         int ret;
164         pid_t pid;
165         pid_t child;
166
167         tfork_global = (struct tfork_state *)
168                 anonymous_shared_allocate(sizeof(struct tfork_state));
169         if (tfork_global == NULL) {
170                 return -1;
171         }
172
173         tfork_global->parent = parent;
174         tfork_global->status_pipe[0] = -1;
175         tfork_global->status_pipe[1] = -1;
176
177         tfork_global->level0_pid = getpid();
178         tfork_global->level0_status = -1;
179         tfork_global->level1_pid = -1;
180         tfork_global->level1_errno = ECANCELED;
181         tfork_global->level2_pid = -1;
182         tfork_global->level2_errno = ECANCELED;
183         tfork_global->level3_pid = -1;
184
185         if (status_fd != NULL) {
186                 ret = pipe(&tfork_global->status_pipe[0]);
187                 if (ret != 0) {
188                         int saved_errno = errno;
189
190                         anonymous_shared_free(tfork_global);
191                         tfork_global = NULL;
192                         errno = saved_errno;
193                         return -1;
194                 }
195
196                 *status_fd = tfork_global->status_pipe[0];
197         }
198
199         /*
200          * We need to set our own signal handler to prevent any existing signal
201          * handler from reaping our child.
202          */
203         tfork_global->old_sig_chld = CatchSignal(SIGCHLD, tfork_sig_chld);
204
205         pid = fork();
206         if (pid == 0) {
207                 int level2_pipe[2];
208                 char c;
209                 ssize_t nread;
210
211                 /*
212                  * Child level 1.
213                  *
214                  * Restore SIGCHLD handler
215                  */
216                 CatchSignal(SIGCHLD, SIG_DFL);
217
218                 /*
219                  * Close read end of the signal pipe, we don't need it anymore
220                  * and don't want to leak it into childs.
221                  */
222                 if (tfork_global->status_pipe[0] != -1) {
223                         close(tfork_global->status_pipe[0]);
224                         tfork_global->status_pipe[0] = -1;
225                 }
226
227                 /*
228                  * Create a pipe for waiting for the child level 2 to finish
229                  * forking.
230                  */
231                 ret = pipe(&level2_pipe[0]);
232                 if (ret != 0) {
233                         tfork_global->level1_errno = errno;
234                         _exit(0);
235                 }
236
237                 pid = fork();
238                 if (pid == 0) {
239
240                         /*
241                          * Child level 2.
242                          */
243
244                         close(level2_pipe[0]);
245                         return level2_fork_and_wait(level2_pipe[1]);
246                 }
247
248                 tfork_global->level2_pid = pid;
249                 if (tfork_global->level2_pid == -1) {
250                         tfork_global->level1_errno = errno;
251                         _exit(0);
252                 }
253
254                 close(level2_pipe[1]);
255                 level2_pipe[1] = -1;
256
257                 nread = sys_read(level2_pipe[0], &c, 1);
258                 if (nread != 1) {
259                         abort();
260                 }
261                 _exit(0);
262         }
263
264         tfork_global->level1_pid = pid;
265         if (tfork_global->level1_pid == -1) {
266                 int saved_errno = errno;
267
268                 anonymous_shared_free(tfork_global);
269                 tfork_global = NULL;
270                 errno = saved_errno;
271                 return -1;
272         }
273
274         /*
275          * By using the helper variable pid we avoid a TOCTOU with the signal
276          * handler that will set tfork_global->level1_pid to -1 (which would
277          * cause waitpid() to block waiting for another exitted child).
278          *
279          * We can't avoid the race waiting for pid twice (in the signal handler
280          * and then again here in the while loop), but we must avoid waiting for
281          * -1 and this does the trick.
282          */
283         pid = tfork_global->level1_pid;
284
285         while (tfork_global->level1_pid != -1) {
286                 ret = waitpid(pid, &tfork_global->level0_status, 0);
287                 if (ret == -1 && errno == EINTR) {
288                         continue;
289                 }
290
291                 break;
292         }
293
294         CatchSignal(SIGCHLD, tfork_global->old_sig_chld);
295
296         if (tfork_global->level0_status != 0) {
297                 anonymous_shared_free(tfork_global);
298                 tfork_global = NULL;
299                 errno = ECHILD;
300                 return -1;
301         }
302
303         if (tfork_global->level2_pid == -1) {
304                 int saved_errno = tfork_global->level1_errno;
305
306                 anonymous_shared_free(tfork_global);
307                 tfork_global = NULL;
308                 errno = saved_errno;
309                 return -1;
310         }
311
312         if (tfork_global->level3_pid == -1) {
313                 int saved_errno = tfork_global->level2_errno;
314
315                 anonymous_shared_free(tfork_global);
316                 tfork_global = NULL;
317                 errno = saved_errno;
318                 return -1;
319         }
320
321         child = tfork_global->level3_pid;
322         anonymous_shared_free(tfork_global);
323         tfork_global = NULL;
324
325         return child;
326 }