/* This file is part of GNU Pies testsuite.
Copyright (C) 2019-2020 Sergey Poznyakoff
GNU Pies is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3, or (at your option)
any later version.
GNU Pies is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with GNU Pies. If not, see <http://www.gnu.org/licenses/>. */
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include <errno.h>
#include "libpies.h"
/* Name the program was invoked under (set from argv[0] in main);
   used as the prefix of diagnostic and usage messages.  */
const char *progname;
/* Write a short usage summary to FP and terminate the program with
   exit status STATUS.  This function does not return.  */
void
usage (FILE *fp, int status)
{
  fprintf (fp, "usage: %s [-s SOCKET] COMMAND ARGS...\n", progname);
  /* The remaining lines contain no conversions, so emit them at once.  */
  fputs ("Test tool for accept and pass-fd pies components.\n"
         "Listens on the file descriptor, either 0 or obtained from SOCKET.\n"
         "For each connection, execs COMMAND ARGS as a separate process.\n",
         fp);
  exit (status);
}
/* Create a UNIX-domain stream socket bound to SOCKET_NAME and put it
   into the listening state (backlog 8).

   Returns the listening descriptor.  Returns -1, after printing a
   diagnostic, if SOCKET_NAME (including its terminating NUL) does not
   fit into sun_path.  Failures of socket(2), bind(2) or listen(2) are
   reported with perror and terminate the program with status 1.

   The socket file is created under umask 0117 (no execute bits, no
   access for others).  The previous umask is restored right after
   bind, so it is not inherited by the commands this program runs.  */
static int
listen_socket (char const *socket_name)
{
  struct sockaddr_un addr;
  size_t len = strlen (socket_name);
  mode_t old_mask;
  int sockfd;
  int rc;

  /* sun_path must hold the name *and* its terminating NUL.  */
  if (len >= sizeof addr.sun_path)
    {
      fprintf (stderr, "%s: UNIX socket name too long\n", progname);
      return -1;
    }
  memset (&addr, 0, sizeof addr);
  addr.sun_family = AF_UNIX;
  memcpy (addr.sun_path, socket_name, len + 1);

  sockfd = socket (PF_UNIX, SOCK_STREAM, 0);
  if (sockfd == -1)
    {
      perror ("socket");
      exit (1);
    }

  /* umask cannot fail and does not touch errno, so restoring it before
     checking the bind result keeps perror accurate.  */
  old_mask = umask (0117);
  rc = bind (sockfd, (struct sockaddr *) &addr, sizeof (addr));
  umask (old_mask);
  if (rc < 0)
    {
      perror ("bind");
      exit (1);
    }

  if (listen (sockfd, 8) < 0)
    {
      perror ("listen");
      exit (1);
    }
  return sockfd;
}
/* Receive a file descriptor sent by the peer over the connected
   UNIX-domain socket FD.  The descriptor travels as SCM_RIGHTS
   ancillary data (or in msg_accrights on systems with the older
   interface), accompanied by a single data byte.

   Returns the received descriptor, or -1 if recvmsg fails, the peer
   has closed the connection, or the message carried no usable
   descriptor.  Exits with status 77 when the platform provides no way
   to receive descriptors.  */
static int
read_fd (int fd)
{
  struct msghdr msg;
  struct iovec iov[1];
  char base[1];
#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
  /* The union gives the control buffer the alignment of struct cmsghdr.  */
  union
  {
    struct cmsghdr cm;
    char control[CMSG_SPACE (sizeof (int))];
  } control_un;
  struct cmsghdr *cmptr;
#elif HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS
  int newfd;
#endif

  /* Zero every field (msg_name, msg_namelen, msg_flags, ...) before
     filling in the ones we use.  */
  memset (&msg, 0, sizeof msg);
#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
  msg.msg_control = control_un.control;
  msg.msg_controllen = sizeof (control_un.control);
#elif HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS
  msg.msg_accrights = (caddr_t) &newfd;
  msg.msg_accrightslen = sizeof (int);
#else
  fprintf (stderr, "no way to get fd\n");
  exit (77);
#endif
  iov[0].iov_base = base;
  iov[0].iov_len = sizeof (base);
  msg.msg_iov = iov;
  msg.msg_iovlen = 1;

  if (recvmsg (fd, &msg, 0) > 0)
    {
#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
      if ((cmptr = CMSG_FIRSTHDR (&msg)) != NULL
          && cmptr->cmsg_len == CMSG_LEN (sizeof (int))
          && cmptr->cmsg_level == SOL_SOCKET
          && cmptr->cmsg_type == SCM_RIGHTS)
        {
          int newfd;

          /* CMSG_DATA need not be suitably aligned for int, and reading
             it through an int pointer violates aliasing rules: copy the
             payload out instead.  */
          memcpy (&newfd, CMSG_DATA (cmptr), sizeof newfd);
          return newfd;
        }
#elif HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS
      if (msg.msg_accrightslen == sizeof (int))
        return newfd;
#endif
    }
  return -1;
}
/* Accept a single connection on the listening socket LFD, read the
   descriptor the peer passes over it, then close that connection.
   Returns whatever read_fd yields (a descriptor, or -1).  Terminates
   the program with status 1 if accept fails.  */
static int
get_fd (int lfd)
{
  int conn;
  int passed;

  conn = accept (lfd, NULL, NULL);
  if (conn < 0)
    {
      perror ("accept");
      exit (1);
    }

  passed = read_fd (conn);
  close (conn);
  return passed;
}
/* SIGCHLD handler: reap every child process that has already
   terminated, then re-install itself (for systems on which signal()
   resets the disposition to SIG_DFL on delivery).

   waitpid with WNOHANG returns 0 while children exist but none has
   exited yet, so the loop must stop on 0 as well as on -1; otherwise
   the handler spins forever whenever another connection's command is
   still running.  errno is saved and restored because the handler can
   interrupt code that is about to inspect it.  */
static void
sigchld (int sig)
{
  int saved_errno = errno;

  while (waitpid ((pid_t) -1, NULL, WNOHANG) > 0)
    ;
  signal (sig, sigchld);
  errno = saved_errno;
}
/* Termination handler: forward signal SIG to every process in this
   process group, which includes the command processes forked for
   accepted connections unless they moved to another group.  Then exit
   with status 0.

   The signal is ignored in this process before it is broadcast.  This
   keeps kill(0, ...) from terminating us abnormally on systems where
   signal() resets the handler to SIG_DFL on delivery.  _exit is used
   because exit() is not async-signal-safe.  */
static void
sigquit (int sig)
{
  signal (sig, SIG_IGN);
  kill (0, sig);
  _exit (0);
}
/* Entry point.

   Usage: PROGNAME [-s SOCKET] COMMAND ARGS...

   The listening descriptor is either 0 or, with -s, a descriptor
   received from a peer that connects to the UNIX socket SOCKET.
   Connections are then accepted forever.  For each one, a child is
   forked that makes the connection its stdin, stdout and stderr and
   execs COMMAND ARGS.  Exit statuses: 64 for usage errors, 1 for setup
   or accept failures, 127 in the child if exec fails.  */
int
main (int argc, char **argv)
{
  int c;
  int fd;
  char *socket_name = NULL;

  progname = argv[0];
  while ((c = getopt (argc, argv, "hs:")) != -1)
    {
      switch (c)
        {
        case 'h':
          usage (stdout, 0);
          break;

        case 's':
          socket_name = optarg;
          break;

        default:
          exit (64);
        }
    }
  argc -= optind;
  argv += optind;
  if (argc == 0)
    usage (stderr, 64);

  if (socket_name)
    {
      int sfd = listen_socket (socket_name);

      if (sfd == -1)
        exit (1);               /* listen_socket printed the reason.  */
      fd = get_fd (sfd);
      close (sfd);
      if (fd == -1)
        {
          fprintf (stderr, "%s: no file descriptor received on %s\n",
                   progname, socket_name);
          exit (1);
        }
    }
  else
    fd = 0;

  signal (SIGCHLD, sigchld);
  signal (SIGTERM, sigquit);
  signal (SIGHUP, sigquit);
  signal (SIGINT, sigquit);
  signal (SIGQUIT, sigquit);

  for (;;)
    {
      pid_t pid;
      int cfd = accept (fd, NULL, NULL);

      if (cfd == -1)
        {
          /* SIGCHLD can interrupt accept where signal() does not
             restart system calls; that is not a real failure.  */
          if (errno == EINTR)
            continue;
          perror ("accept");
          exit (1);
        }

      pid = fork ();
      if (pid == 0)
        {
          int i;

          /* Child: close every descriptor up to getmaxfd() except the
             connection, make the connection fds 0, 1 and 2, then run
             the command.  */
          for (i = getmaxfd (); i >= 0; i--)
            if (i != cfd)
              close (i);
          if (cfd != 0)
            dup2 (cfd, 0);
          if (cfd != 1)
            dup2 (cfd, 1);
          if (cfd != 2)
            dup2 (cfd, 2);
          if (cfd > 2)
            close (cfd);
          execvp (argv[0], argv);
          exit (127);
        }
      if (pid == -1)
        perror ("fork");
      /* The parent never talks to the client itself.  */
      close (cfd);
    }
  return 0;
}