/* This file is part of Pies.
Copyright (C) 2007, 2008, 2009 Sergey Poznyakoff
Pies is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3, or (at your option)
any later version.
Pies is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Pies. If not, see <http://www.gnu.org/licenses/>. */
#include "pies.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
/* Set the effective GID to GID, logging (but otherwise ignoring) failure.  */
static void
set_egid_logged (gid_t gid)
{
  if (setegid (gid))
    logmsg (LOG_ERR, _("Cannot switch to EGID %lu: %s"),
            (unsigned long) gid, strerror (errno));
}

/* Set the effective UID to UID, logging (but otherwise ignoring) failure.  */
static void
set_euid_logged (uid_t uid)
{
  if (seteuid (uid))
    logmsg (LOG_ERR, _("Cannot switch to EUID %lu: %s"),
            (unsigned long) uid, strerror (errno));
}

/* Swap the process's effective UID, effective GID and umask with the
   values pointed to by PUID, PGID and PUMASK.

   On entry *PUID, *PGID and *PUMASK hold the desired values; on return
   they hold the values that were in effect before the call, so calling
   the function a second time with the same pointers restores the
   original state.  Failures to change the IDs are logged only.

   Changing the EGID to an arbitrary value requires privilege.  Therefore,
   when the current EUID is 0 (dropping privileges) the EGID is changed
   first, while still privileged.  Otherwise (e.g. switching back to root)
   the EUID is changed first, so that the following setegid runs with the
   regained privileges.  */
static void
switch_eids (uid_t *puid, gid_t *pgid, mode_t *pumask)
{
  uid_t ouid = geteuid ();
  gid_t ogid = getegid ();
  mode_t omask = umask (*pumask);

  if (ouid == 0)
    {
      set_egid_logged (*pgid);
      set_euid_logged (*puid);
    }
  else
    {
      set_euid_logged (*puid);
      set_egid_logged (*pgid);
    }

  *puid = ouid;
  *pgid = ogid;
  *pumask = omask;
}
/* Return the value of URL argument ARG ("KEYWORD=VALUE") if its keyword
   is exactly KW, or NULL otherwise.  Arguments without '=' carry no value
   and never match, which avoids reading past the end of ARG.  */
static const char *
url_arg_value (const char *arg, const char *kw)
{
  size_t len = strcspn (arg, "=");

  if (arg[len] != '=' || len != strlen (kw) || strncmp (arg, kw, len) != 0)
    return NULL;
  return arg + len + 1;
}

/* Convert the octal number VAL and store it in *PN.  Return 0 on success,
   or -1 if VAL is empty or contains characters that are not octal digits.  */
static int
url_arg_octal (const char *val, unsigned long *pn)
{
  char *end;
  unsigned long n = strtoul (val, &end, 8);

  if (end == val || *end)
    return -1;
  *pn = n;
  return 0;
}

/* Create a stream socket for URL and bind it to the address URL describes.

   Schemes "unix", "file" and "socket" denote a UNIX-domain socket at
   URL->path.  Such URLs may carry "user=NAME", "group=NAME", "umask=OCTAL"
   and "mode=OCTAL" arguments.  The socket is then bound while running with
   that user's/group's effective IDs and the given umask ("mode" is
   converted to the equivalent umask).  An existing socket file at the path
   is unlinked first; any other kind of existing file is an error.
   NOTE(review): for these schemes the USER parameter is replaced by
   URL->user; for "inet" it is not consulted at all.

   Scheme "inet" denotes an IPv4 socket bound to URL->host (or to all
   interfaces when no host is given) and URL->port.

   UMASKVAL is the umask to apply while binding unless overridden by URL
   arguments.

   Returns the bound socket descriptor, or -1 on error (the error is
   logged).  */
int
create_socket (struct pies_url *url, const char *user, mode_t umaskval)
{
  int rc;
  int fd;
  int bind_errno;
  union
  {
    struct sockaddr sa;
    struct sockaddr_in s_in;
    struct sockaddr_un s_un;
  } addr;
  socklen_t socklen;
  uid_t uid = 0;
  gid_t gid = 0;
  int switch_back;

  /* Start from an all-zero address so that sin_zero and any field not
     explicitly set below hold no stack garbage.  */
  memset (&addr, 0, sizeof addr);

  if (strcmp (url->proto, "unix") == 0
      || strcmp (url->proto, "file") == 0
      || strcmp (url->proto, "socket") == 0)
    {
      struct stat st;
      const char *group = NULL;
      size_t i;

      user = url->user;
      for (i = 0; i < url->argc; i++)
        {
          const char *arg = url->argv[i];
          const char *val;
          unsigned long n;

          if ((val = url_arg_value (arg, "user")) != NULL)
            user = val;
          else if ((val = url_arg_value (arg, "group")) != NULL)
            group = val;
          else if ((val = url_arg_value (arg, "umask")) != NULL)
            {
              if (url_arg_octal (val, &n))
                logmsg (LOG_ERR, _("%s: invalid octal number (%s)"),
                        url->string, val);
              else if (n & ~0777UL)
                logmsg (LOG_ERR, _("%s: invalid umask (%s)"),
                        url->string, val);
              else
                umaskval = n & 0777;
            }
          else if ((val = url_arg_value (arg, "mode")) != NULL)
            {
              if (url_arg_octal (val, &n))
                logmsg (LOG_ERR, _("%s: invalid octal number (%s)"),
                        url->string, val);
              else if (n & ~0777UL)
                logmsg (LOG_ERR, _("%s: invalid mode (%s)"),
                        url->string, val);
              else
                umaskval = 0777 & ~n;
            }
          /* Other arguments are ignored.  */
        }

      if (user)
        {
          struct passwd *pw = getpwnam (user);
          if (!pw)
            {
              logmsg (LOG_ERR, _("no such user: %s"), user);
              return -1;
            }
          uid = pw->pw_uid;
          gid = pw->pw_gid;
        }

      if (group)
        {
          struct group *grp = getgrnam (group);
          if (!grp)
            {
              logmsg (LOG_ERR, _("no such group: %s"), group);
              return -1;
            }
          gid = grp->gr_gid;
        }

      /* sun_path must also hold the terminating NUL written by strcpy.  */
      if (strlen (url->path) >= sizeof addr.s_un.sun_path)
        {
          errno = EINVAL;
          logmsg (LOG_ERR, _("%s: UNIX socket name too long"), url->path);
          return -1;
        }
      addr.sa.sa_family = PF_UNIX;
      socklen = sizeof (addr.s_un);
      strcpy (addr.s_un.sun_path, url->path);

      if (stat (url->path, &st))
        {
          if (errno != ENOENT)
            {
              logmsg (LOG_ERR, _("%s: cannot stat socket: %s"),
                      url->string, strerror (errno));
              return -1;
            }
        }
      else
        {
          /* FIXME: Check permissions? */
          if (!S_ISSOCK (st.st_mode))
            {
              logmsg (LOG_ERR, _("%s: not a socket"), url->string);
              return -1;
            }
          if (unlink (url->path))
            {
              logmsg (LOG_ERR, _("%s: cannot unlink: %s"),
                      url->path, strerror (errno));
              return -1;
            }
        }
    }
  else if (strcmp (url->proto, "inet") == 0)
    {
      const char *host = url->host;
      unsigned short port = url->port;

      addr.sa.sa_family = PF_INET;
      socklen = sizeof (addr.s_in);

      if (!host)
        addr.s_in.sin_addr.s_addr = htonl (INADDR_ANY);
      else
        {
          struct hostent *hp = gethostbyname (host);
          if (!hp)
            {
              logmsg (LOG_ERR, _("%s: Unknown host name %s"),
                      url->string, host);
              return -1;
            }
          addr.sa.sa_family = hp->h_addrtype;
          switch (hp->h_addrtype)
            {
            case AF_INET:
              memmove (&addr.s_in.sin_addr, hp->h_addr,
                       sizeof addr.s_in.sin_addr);
              break;

            default:
              logmsg (LOG_ERR, _("%s: unsupported address family"),
                      url->string);
              return -1;
            }
        }
      /* The port applies both to a named host and to INADDR_ANY.  */
      addr.s_in.sin_port = htons (port);
    }
  else
    {
      logmsg (LOG_ERR, "%s: unknown scheme", url->string);
      return -1;
    }

  fd = socket (addr.sa.sa_family, SOCK_STREAM, 0);
  if (fd == -1)
    {
      logmsg (LOG_ERR, _("%s: cannot create socket: %s"),
              url->string, strerror (errno));
      return -1;
    }

  rc = 1;
  if (addr.sa.sa_family != PF_UNIX
      && setsockopt (fd, SOL_SOCKET, SO_REUSEADDR, (void *) &rc,
                     sizeof (rc)) == -1)
    {
      logmsg (LOG_ERR, _("%s: set reuseaddr failed (%s)"),
              url->string, strerror (errno));
      close (fd);
      return -1;
    }

  /* Bind with the requested credentials and umask so that a UNIX socket
     file is created with the right owner and permissions.  */
  if (uid || gid || umaskval)
    {
      switch_eids (&uid, &gid, &umaskval);
      switch_back = 1;
    }
  else
    switch_back = 0;

  rc = bind (fd, &addr.sa, socklen);
  /* Save errno: switching the IDs back may overwrite it.  */
  bind_errno = errno;

  if (switch_back)
    switch_eids (&uid, &gid, &umaskval);

  if (rc < 0)
    {
      logmsg (LOG_ERR, _("%s: cannot bind: %s"),
              url->string, strerror (bind_errno));
      close (fd);
      return -1;
    }
  return fd;
}
/* Send descriptor PAYLOAD over the connected UNIX-domain socket FD.
   The descriptor travels as SCM_RIGHTS ancillary data (or as
   msg_accrights on systems with the older interface), accompanied by one
   dummy data byte.

   Returns 0 on success and 1 on failure (including when the system offers
   no way to pass descriptors).  */
static int
pass_fd0 (int fd, int payload)
{
  struct msghdr msg;
  struct iovec iov[1];

  memset (&msg, 0, sizeof msg);
#if HAVE_STRUCT_MSGHDR_MSG_CONTROL
# ifndef CMSG_LEN
#  define CMSG_LEN(size) (sizeof(struct cmsghdr) + (size))
# endif /* ! CMSG_LEN */
# ifndef CMSG_SPACE
#  define CMSG_SPACE(size) (sizeof(struct cmsghdr) + (size))
# endif /* ! CMSG_SPACE */
  /* The union guarantees the control buffer is suitably aligned for a
     struct cmsghdr; a plain char array carries no such guarantee.  */
  union
  {
    struct cmsghdr align;
    char buf[CMSG_SPACE (sizeof (int))];
  } control;
  struct cmsghdr *cmptr;

  memset (&control, 0, sizeof control);
  msg.msg_control = control.buf;
  msg.msg_controllen = CMSG_LEN (sizeof (int));
  cmptr = CMSG_FIRSTHDR (&msg);
  cmptr->cmsg_len = CMSG_LEN (sizeof (int));
  cmptr->cmsg_level = SOL_SOCKET;
  cmptr->cmsg_type = SCM_RIGHTS;
  /* CMSG_DATA need not be aligned for int; copy bytewise.  */
  memcpy (CMSG_DATA (cmptr), &payload, sizeof payload);
#elif HAVE_STRUCT_MSGHDR_MSG_ACCRIGHTS
  msg.msg_accrights = (caddr_t) &payload;
  msg.msg_accrightslen = sizeof (int);
#else
  logmsg (LOG_ERR, _("no way to send fd"));
  return 1;
#endif /* HAVE_MSGHDR_MSG_CONTROL */

  msg.msg_name = NULL;
  msg.msg_namelen = 0;

  /* At least one byte of regular data is sent with the descriptor.  */
  iov[0].iov_base = "";
  iov[0].iov_len = 1;
  msg.msg_iov = iov;
  msg.msg_iovlen = 1;

  return sendmsg (fd, &msg, 0) == -1;
}
/* Pause briefly (100 ms) between retries in pass_fd, so that waiting for
   the peer does not busy-spin the CPU.  */
static void
pass_fd_delay (void)
{
  struct timeval tv;

  tv.tv_sec = 0;
  tv.tv_usec = 100000;
  select (0, NULL, NULL, NULL, &tv);
}

/* Pass descriptor FD to the process listening on the UNIX socket
   SOCKET_NAME.  Waits up to MAXTIME seconds for the socket file to appear,
   for the connection to be accepted and for the socket to become writable.

   Returns the result of pass_fd0 (0 on success, 1 if sending failed), or
   -1 if the socket name is too long, the wait timed out, or an error
   occurred before the descriptor could be sent.  Errors are logged.  */
int
pass_fd (const char *socket_name, int fd, unsigned maxtime)
{
  enum { fds_init, fds_open, fds_connected, fds_ready } state = fds_init;
  static const char *const fds_descr[] =
    { "init", "open", "connected", "ready" };
  time_t start = time (NULL);
  int sockfd = -1;
  int res = -1;
  struct sockaddr_un addr;

  /* sun_path must also hold the terminating NUL written by strcpy.  */
  if (strlen (socket_name) >= sizeof addr.sun_path)
    {
      logmsg (LOG_ERR, _("%s: UNIX socket name too long"), socket_name);
      return -1;
    }
  memset (&addr, 0, sizeof addr);
  addr.sun_family = AF_UNIX;
  strcpy (addr.sun_path, socket_name);

  for (;;)
    {
      time_t now = time (NULL);

      if (now - start > maxtime)
        {
          logmsg (LOG_ERR, _("pass-fd timed out in state %s"),
                  fds_descr[state]);
          break;
        }

      if (state == fds_init)
        {
          struct stat st;

          if (stat (socket_name, &st))
            {
              if (errno != ENOENT)
                {
                  logmsg (LOG_ERR, _("cannot stat %s: %s"),
                          socket_name, strerror (errno));
                  break;
                }
              /* The peer has not created the socket yet.  */
              pass_fd_delay ();
              continue;
            }
          if (!S_ISSOCK (st.st_mode))
            {
              logmsg (LOG_ERR, _("%s: not a socket"), socket_name);
              break;
            }
          sockfd = socket (PF_UNIX, SOCK_STREAM, 0);
          if (sockfd == -1)
            {
              if (errno == EINTR)
                continue;
              logmsg (LOG_ERR, "socket: %s", strerror (errno));
              break;
            }
          state = fds_open;
        }

      if (state == fds_open)
        {
          if (connect (sockfd, (struct sockaddr *) &addr, sizeof (addr)) == 0
              || errno == EISCONN)
            /* EISCONN: an earlier, interrupted attempt has completed.  */
            state = fds_connected;
          else if (errno == EINTR)
            continue;
          else if (errno == EALREADY)
            {
              /* A previous attempt is still in progress.  */
              pass_fd_delay ();
              continue;
            }
          else if (errno == ECONNREFUSED || errno == EAGAIN)
            {
              /* The listener is not ready yet.  POSIX leaves the state of
                 a socket unspecified after a failed connect, so start
                 over with a fresh socket after a short pause.  */
              close (sockfd);
              sockfd = -1;
              state = fds_init;
              pass_fd_delay ();
              continue;
            }
          else
            {
              logmsg (LOG_ERR, _("%s: connect failed: %s"),
                      socket_name, strerror (errno));
              break;
            }
        }

      if (state == fds_connected)
        {
          int rc;
          fd_set fds;
          struct timeval tv;

          FD_ZERO (&fds);
          FD_SET (sockfd, &fds);
          tv.tv_usec = 0;
          tv.tv_sec = maxtime - (now - start);
          rc = select (sockfd + 1, NULL, &fds, NULL, &tv);
          if (rc == 0)
            continue;
          if (rc < 0)
            {
              if (errno == EINTR)
                continue;
              logmsg (LOG_ERR, _("select failed: %s"), strerror (errno));
              break;
            }
          state = fds_ready;
        }

      if (state == fds_ready)
        {
          res = pass_fd0 (sockfd, fd);
          break;
        }
    }

  if (sockfd >= 0)
    close (sockfd);
  return res;
}
fd_set listenset;
int fd_max;

/* Put socket FD into listening mode and add it to LISTENSET, the set of
   descriptors that pies_pause waits on; FD_MAX tracks the highest one.

   Returns 0 on success, or 1 on failure (the error is logged).  */
int
register_listener (int fd)
{
  /* FD_SET on a descriptor outside [0, FD_SETSIZE) is undefined
     behaviour, so such descriptors cannot be watched with select.  */
  if (fd < 0 || fd >= FD_SETSIZE)
    {
      logmsg (LOG_ERR, _("cannot listen on descriptor %d: out of range for select"),
              fd);
      return 1;
    }
  if (listen (fd, 8) == -1)
    {
      logmsg (LOG_ERR, _("listen: %s"), strerror (errno));
      return 1;
    }
  FD_SET (fd, &listenset);
  if (fd > fd_max)
    fd_max = fd;
  return 0;
}
/* Wait for incoming connections on the sockets registered in LISTENSET
   and hand every ready descriptor to progman_accept.

   The wait is repeated until select fails.  When a signal interrupts
   select (EINTR) the function returns silently; any other select error
   is logged before returning.  */
void
pies_pause ()
{
  for (;;)
    {
      fd_set ready = listenset;
      int pending = select (fd_max + 1, &ready, NULL, NULL, NULL);
      int fd;

      if (pending < 0)
        {
          if (errno != EINTR)
            logmsg (LOG_ERR, "select: %s", strerror (errno));
          return;
        }

      /* Stop scanning once every ready descriptor has been served.  */
      for (fd = 0; pending > 0 && fd <= fd_max; fd++)
        {
          if (!FD_ISSET (fd, &ready))
            continue;
          pending--;
          progman_accept (fd);
        }
    }
}