Prevent the reading of another message before the end of the current one.
author Wayne Davison <wayned@samba.org>
Sun, 13 Dec 2009 05:53:19 +0000 (21:53 -0800)
committer Wayne Davison <wayned@samba.org>
Sun, 13 Dec 2009 05:54:52 +0000 (21:54 -0800)
io.c

diff --git a/io.c b/io.c
index d0ddaf5..2c162a6 100644
--- a/io.c
+++ b/io.c
@@ -84,7 +84,7 @@ static struct {
        xbuf in, out, msg;
        int in_fd;
        int out_fd; /* Both "out" and "msg" go to this fd. */
-       BOOL in_multiplexed;
+       int in_multiplexed;
        unsigned out_empty_len;
        size_t raw_data_header_pos;      /* in the out xbuf */
        size_t raw_flushing_ends_before; /* in the out xbuf */
@@ -127,7 +127,8 @@ static char int_byte_extra[64] = {
 #define IOBUF_WAS_REDUCED(siz) ((siz) & 0xFF)
 #define IOBUF_RESTORE_SIZE(siz) (((siz) | 0xFF) + 1)
 
-#define IN_MULTIPLEXED (iobuf.in_multiplexed)
+#define IN_MULTIPLEXED (iobuf.in_multiplexed != 0)
+#define IN_MULTIPLEXED_AND_READY (iobuf.in_multiplexed > 0)
 #define OUT_MULTIPLEXED (iobuf.out_empty_len != 0)
 
 #define PIO_NEED_INPUT (1<<0) /* The *_NEED_* flags are mutually exclusive. */
@@ -826,7 +827,7 @@ static char *perform_io(size_t needed, int flags)
 
                /* We need to help prevent deadlock by doing what reading
                 * we can whenever we are here trying to write. */
-               if (IN_MULTIPLEXED && !(flags & PIO_NEED_INPUT)) {
+               if (IN_MULTIPLEXED_AND_READY && !(flags & PIO_NEED_INPUT)) {
                        while (!iobuf.raw_input_ends_before && iobuf.in.len > 512)
                                read_a_msg();
                        if (flist_receiving_enabled && iobuf.in.len > 512)
@@ -1351,12 +1352,18 @@ static void read_a_msg(void)
        int tag, val;
        size_t msg_bytes;
 
+       /* This ensures that perform_io() does not try to do any message reading
+        * until we've read all of the data for this message.  We should also
+        * try to avoid calling things that will cause data to be written via
+        * perform_io() prior to this being reset to 1. */
+       iobuf.in_multiplexed = -1;
+
        tag = raw_read_int();
 
        msg_bytes = tag & 0xFFFFFF;
        tag = (tag >> 24) - MPLEX_BASE;
 
-       if (DEBUG_GTE(IO, 1) && (msgs2stderr || tag != MSG_INFO))
+       if (DEBUG_GTE(IO, 1) && msgs2stderr)
                rprintf(FINFO, "[%s] got msg=%d, len=%ld\n", who_am_i(), (int)tag, (long)msg_bytes);
 
        switch (tag) {
@@ -1368,21 +1375,26 @@ static void read_a_msg(void)
                 * which case the gradual reading of the input stream will
                 * cause this value to decrease and eventually become real. */
                iobuf.raw_input_ends_before = iobuf.in.pos + msg_bytes;
+               iobuf.in_multiplexed = 1;
                break;
        case MSG_STATS:
                if (msg_bytes != sizeof stats.total_read || !am_generator)
                        goto invalid_msg;
                raw_read_buf((char*)&stats.total_read, sizeof stats.total_read);
+               iobuf.in_multiplexed = 1;
                break;
        case MSG_REDO:
                if (msg_bytes != 4 || !am_generator)
                        goto invalid_msg;
-               got_flist_entry_status(FES_REDO, raw_read_int());
+               val = raw_read_int();
+               iobuf.in_multiplexed = 1;
+               got_flist_entry_status(FES_REDO, val);
                break;
        case MSG_IO_ERROR:
                if (msg_bytes != 4 || am_sender)
                        goto invalid_msg;
                val = raw_read_int();
+               iobuf.in_multiplexed = 1;
                io_error |= val;
                if (!am_generator)
                        send_msg_int(MSG_IO_ERROR, val);
@@ -1391,6 +1403,7 @@ static void read_a_msg(void)
                if (msg_bytes != 4 || am_server || am_generator)
                        goto invalid_msg;
                val = raw_read_int();
+               iobuf.in_multiplexed = 1;
                if (!io_timeout || io_timeout > val) {
                        if (INFO_GTE(MISC, 2))
                                rprintf(FINFO, "Setting --timeout=%d to match server\n", val);
@@ -1400,12 +1413,14 @@ static void read_a_msg(void)
        case MSG_NOOP:
                if (am_sender)
                        maybe_send_keepalive();
+               iobuf.in_multiplexed = 1;
                break;
        case MSG_DELETED:
                if (msg_bytes >= sizeof data)
                        goto overflow;
                if (am_generator) {
                        raw_read_buf(data, msg_bytes);
+                       iobuf.in_multiplexed = 1;
                        send_msg(MSG_DELETED, data, msg_bytes, 1);
                        break;
                }
@@ -1444,6 +1459,7 @@ static void read_a_msg(void)
                } else
 #endif
                        raw_read_buf(data, msg_bytes);
+               iobuf.in_multiplexed = 1;
                /* A directory name was sent with the trailing null */
                if (msg_bytes > 0 && !data[msg_bytes-1])
                        log_delete(data, S_IFDIR);
@@ -1461,6 +1477,7 @@ static void read_a_msg(void)
                        exit_cleanup(RERR_STREAMIO);
                }
                val = raw_read_int();
+               iobuf.in_multiplexed = 1;
                if (am_generator)
                        got_flist_entry_status(FES_SUCCESS, val);
                else
@@ -1470,6 +1487,7 @@ static void read_a_msg(void)
                if (msg_bytes != 4)
                        goto invalid_msg;
                val = raw_read_int();
+               iobuf.in_multiplexed = 1;
                if (am_generator)
                        got_flist_entry_status(FES_NO_SEND, val);
                else
@@ -1497,6 +1515,7 @@ static void read_a_msg(void)
                        exit_cleanup(RERR_STREAMIO);
                }
                raw_read_buf(data, msg_bytes);
+               iobuf.in_multiplexed = 1;
                rwrite((enum logcode)tag, data, msg_bytes, !am_generator);
                if (first_message) {
                        if (list_only && !am_sender && tag == 1 && msg_bytes < sizeof data) {
@@ -1507,6 +1526,13 @@ static void read_a_msg(void)
                }
                break;
        case MSG_ERROR_EXIT:
+               if (msg_bytes == 4)
+                       val = raw_read_int();
+               else if (msg_bytes == 0)
+                       val = 0;
+               else
+                       goto invalid_msg;
+               iobuf.in_multiplexed = 1;
                if (DEBUG_GTE(EXIT, 3))
                        rprintf(FINFO, "[%s] got MSG_ERROR_EXIT with %d bytes\n", who_am_i(), msg_bytes);
                if (msg_bytes == 0) {
@@ -1519,7 +1545,7 @@ static void read_a_msg(void)
                                io_flush(FULL_FLUSH);
                        }
                        val = 0;
-               } else if (msg_bytes == 4) {
+               } else {
                        val = raw_read_int();
                        if (protocol_version >= 31) {
                                if (am_generator) {
@@ -1536,8 +1562,7 @@ static void read_a_msg(void)
                                        send_msg(MSG_ERROR_EXIT, "", 0, 0);
                                }
                        }
-               } else
-                       goto invalid_msg;
+               }
                /* Send a negative linenum so that we don't end up
                 * with a duplicate exit message. */
                _exit_cleanup(val, __FILE__, 0 - __LINE__);
@@ -1546,11 +1571,13 @@ static void read_a_msg(void)
                        tag, who_am_i(), inc_recurse ? "/inc" : "");
                exit_cleanup(RERR_STREAMIO);
        }
+
+       assert(iobuf.in_multiplexed > 0);
 }
 
 static void drain_multiplex_messages(void)
 {
-       while (IN_MULTIPLEXED && iobuf.in.len) {
+       while (IN_MULTIPLEXED_AND_READY && iobuf.in.len) {
                if (iobuf.raw_input_ends_before) {
                        size_t raw_len = iobuf.raw_input_ends_before - iobuf.in.pos;
                        iobuf.raw_input_ends_before = 0;
@@ -2207,7 +2234,7 @@ void io_start_multiplex_in(int fd)
        if (msgs2stderr && DEBUG_GTE(IO, 2))
                rprintf(FINFO, "[%s] io_start_multiplex_in(%d)\n", who_am_i(), fd);
 
-       iobuf.in_multiplexed = True; /* See also IN_MULTIPLEXED */
+       iobuf.in_multiplexed = 1; /* See also IN_MULTIPLEXED */
        io_start_buffering_in(fd);
 }
 
@@ -2218,7 +2245,7 @@ int io_end_multiplex_in(int mode)
        if (msgs2stderr && DEBUG_GTE(IO, 2))
                rprintf(FINFO, "[%s] io_end_multiplex_in(mode=%d)\n", who_am_i(), mode);
 
-       iobuf.in_multiplexed = False;
+       iobuf.in_multiplexed = 0;
        if (mode == MPLX_SWITCHING)
                iobuf.raw_input_ends_before = 0;
        else