diff --git a/rc_scripts/iredadmin.debian b/rc_scripts/iredadmin.debian new file mode 100644 index 0000000..31962f7 --- /dev/null +++ b/rc_scripts/iredadmin.debian @@ -0,0 +1,106 @@ +#!/usr/bin/env bash +# Author: Zhang Huangbin (zhb@iredmail.org) + +### BEGIN INIT INFO +# Provides: api-server +# Required-Start: $network $syslog +# Required-Stop: $network $syslog +# Default-Start: 2 3 4 5 +# Default-Stop: 0 1 6 +# Short-Description: iredadmin instance +# Description: iredadmin +### END INIT INFO + +PROG='iredadmin' +PIDFILE='/var/run/iredadmin/iredadmin.pid' +UWSGI_INI_FILE='/opt/www/iredadmin/rc_scripts/uwsgi/debian.ini' + +check_status() { + # Usage: check_status pid_number + PID="${1}" + l=$(ps -p ${PID} | wc -l | awk '{print $1}') + if [ X"$l" == X"2" ]; then + echo "running" + else + echo "stopped" + fi +} + +start() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "${PROG} is already running." + exit 0 + else + rm -f ${PIDFILE} >/dev/null 2>&1 + fi + + unset s + fi + + mkdir /var/run/iredadmin 2>/dev/null + chown iredadmin:iredadmin /var/run/iredadmin + chmod 0755 /var/run/iredadmin + + echo "Starting ${PROG} ..." + uwsgi -d \ + --ini ${UWSGI_INI_FILE} \ + --pidfile ${PIDFILE} \ + --log-syslog +} + +stop() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "Stopping ${PROG} ..." + uwsgi --stop ${PIDFILE} + if [ X"$?" == X"0" ]; then + rm -f ${PIDFILE} >/dev/null 2>&1 + rm -rf /var/run/iredadmin + else + echo -e "\t\t[ FAILED ]" + fi + else + echo "${PROG} is already stopped." + rm -f ${PIDFILE} >/dev/null 2>&1 + fi + else + echo "${PROG} is already stopped." + fi + unset s +} + +status() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "${PROG} is running." + exit 0 + else + echo "${PROG} is stopped." + exit 1 + fi + else + echo "${PROG} is stopped." + exit 3 + fi +} + +case "$1" in + start) start ;; + stop) stop ;; + status) status ;; + restart) stop && start ;; + *) + echo $"Usage: $0 {start|stop|restart|status}" + RETVAL=1 + ;; +esac diff --git a/rc_scripts/iredadmin.freebsd b/rc_scripts/iredadmin.freebsd new file mode 100644 index 0000000..a1c7d7a --- /dev/null +++ b/rc_scripts/iredadmin.freebsd @@ -0,0 +1,110 @@ +#!/bin/sh + +# Author: Zhang Huangbin + +# PROVIDE: iredadmin +# REQUIRE: DAEMON +# KEYWORD: shutdown + +. /etc/rc.subr +name='iredadmin' +rcvar=`set_rcvar_obsolete` +start_precmd="iredadmin_precmd" + +RUN_DIR='/var/run/iredadmin' +PIDFILE="${RUN_DIR}/iredadmin.pid" +UWSGI_INI_FILE='/opt/www/iredadmin/rc_scripts/uwsgi/freebsd.ini' + +PATH="/usr/local/bin:/usr/local/sbin:$PATH" + +iredadmin_precmd() { + /usr/bin/install -m 0644 -o iredadmin -g iredadmin -d ${RUN_DIR} +} + +check_status() { + # Usage: check_status pid_number + PID="${1}" + l=$(ps -p ${PID} | wc -l | awk '{print $1}') + if [ X"$l" == X"2" ]; then + echo "running" + else + echo "stopped" + fi +} + +start() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "${name} is already running." + else + rm -f ${PIDFILE} >/dev/null 2>&1 + fi + + unset s + fi + + /bin/mkdir $(dirname ${PIDFILE}) 2>/dev/null + /usr/sbin/chown iredadmin:iredadmin $(dirname ${PIDFILE}) + + echo "Starting ${name}." + uwsgi --ini ${UWSGI_INI_FILE} \ + --pidfile ${PIDFILE} \ + --log-syslog \ + --daemonize /dev/null +} + +stop() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "Stopping ${name}." + uwsgi --stop ${PIDFILE} + if [ X"$?" == X"0" ]; then + rm -f ${PIDFILE} >/dev/null 2>&1 + else + echo -e "\t\t[ FAILED ]" + fi + else + echo "${name} is already stopped." + rm -f ${PIDFILE} >/dev/null 2>&1 + fi + + unset s + else + echo "${name} is already stopped." + fi +} + +status() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "${name} is running." + exit 0 + else + echo "${name} is stopped." + exit 1 + fi + + unset s + else + echo "${name} is stopped." + exit 3 + fi +} + +start_cmd="start" +stop_cmd="stop" +status_cmd="status" +restart_cmd="stop && sleep 2 && start" + +command="start" +load_rc_config ${name} +run_rc_command "$1" diff --git a/rc_scripts/iredadmin.openbsd b/rc_scripts/iredadmin.openbsd new file mode 100644 index 0000000..ed2588e --- /dev/null +++ b/rc_scripts/iredadmin.openbsd @@ -0,0 +1,23 @@ +#!/bin/ksh +# Author: Zhang Huangbin +# Purpose: Start/stop iRedAdmin uwsgi instance. + +RUN_DIR='/var/run/iredadmin' +PID_FILE="${RUN_DIR}/iredadmin.pid" +UWSGI_INI_FILE='/opt/www/iredadmin/rc_scripts/uwsgi/openbsd.ini' + +daemon="/usr/local/bin/uwsgi --ini ${UWSGI_INI_FILE} --log-syslog --pidfile ${PID_FILE} --daemonize /dev/null" +daemon_user='iredadmin' +daemon_group='iredadmin' + +. /etc/rc.d/rc.subr + +rc_pre() { + install -d -o ${daemon_user} -g ${daemon_group} -m 0775 ${RUN_DIR} +} + +rc_stop() { + kill -INT `cat ${PID_FILE}` +} + +rc_cmd $1 diff --git a/rc_scripts/iredadmin.rhel b/rc_scripts/iredadmin.rhel new file mode 100644 index 0000000..d48de23 --- /dev/null +++ b/rc_scripts/iredadmin.rhel @@ -0,0 +1,104 @@ +#!/usr/bin/env bash + +# Author: Zhang Huangbin (zhb@iredmail.org) + +### BEGIN INIT INFO +# chkconfig: - 99 99 +# description: iredadmin instance +# processname: iredadmin +### END INIT INFO + +PROG='iredadmin' +BINPATH='/opt/www/iredadmin/iredadmin.py' +PIDFILE='/var/run/iredadmin/iredadmin.pid' +UWSGI_INI_FILE='/opt/www/iredadmin/rc_scripts/uwsgi/rhel.ini' + +check_status() { + # Usage: check_status pid_number + PID="${1}" + l=$(ps -p ${PID} | wc -l | awk '{print $1}') + if [ X"$l" == X"2" ]; then + echo "running" + else + echo "stopped" + fi +} + +start() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "${PROG} is already running." + else + rm -f ${PIDFILE} >/dev/null 2>&1 + fi + fi + + unset s + + mkdir /var/run/iredadmin 2>/dev/null + chown iredadmin:iredadmin /var/run/iredadmin + chmod 0755 /var/run/iredadmin + + echo "Starting ${PROG} ..." + uwsgi -d \ + --ini ${UWSGI_INI_FILE} \ + --pidfile ${PIDFILE} \ + --log-syslog +} + +stop() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "Stopping ${PROG} ..." + kill -9 ${PID} + if [ X"$?" == X"0" ]; then + rm -f ${PIDFILE} >/dev/null 2>&1 + rm -rf /var/run/iredadmin + else + echo -e "\t\t[ FAILED ]" + fi + else + echo "${PROG} is already stopped." + rm -f ${PIDFILE} >/dev/null 2>&1 + fi + else + echo "${PROG} is already stopped." + fi + + unset s +} + +status() { + if [ -f ${PIDFILE} ]; then + PID="$(cat ${PIDFILE})" + s="$(check_status ${PID})" + + if [ X"$s" == X"running" ]; then + echo "${PROG} is running." + exit 0 + else + echo "${PROG} is stopped." + exit 1 + fi + else + echo "${PROG} is stopped." + exit 3 + fi +} + +case "$1" in + start) start ;; + stop) stop ;; + status) status ;; + restart) stop && sleep 1 && start ;; + *) + echo $"Usage: $0 {start|stop|restart|status}" + RETVAL=1 + ;; +esac diff --git a/rc_scripts/systemd/debian.service b/rc_scripts/systemd/debian.service new file mode 100644 index 0000000..6388dae --- /dev/null +++ b/rc_scripts/systemd/debian.service @@ -0,0 +1,17 @@ +[Unit] +Description=iRedAdmin daemon service +After=network.target local-fs.target remote-fs.target + +[Service] +Type=simple +ExecStartPre=-/bin/mkdir -p /var/run/iredadmin +ExecStartPre=/bin/chown iredadmin:iredadmin /var/run/iredadmin +ExecStartPre=/bin/chmod 0755 /var/run/iredadmin +ExecStart=/usr/bin/uwsgi --ini /opt/www/iredadmin/rc_scripts/uwsgi/debian.ini --pidfile /var/run/iredadmin/iredadmin.pid +ExecStop=/usr/bin/uwsgi --stop /var/run/iredadmin/iredadmin.pid +ExecStopPost=/bin/rm -rf /var/run/iredadmin +KillSignal=SIGTERM +PrivateTmp=true + +[Install] +WantedBy=multi-user.target diff --git a/rc_scripts/systemd/rhel7.service b/rc_scripts/systemd/rhel7.service new file mode 100644 index 0000000..f2fc18c --- /dev/null +++ b/rc_scripts/systemd/rhel7.service @@ -0,0 +1,18 @@ +[Unit] +Description=iRedAdmin daemon service +After=network.target local-fs.target remote-fs.target + +[Service] +Type=simple +ExecStartPre=-/usr/bin/mkdir /var/run/iredadmin +ExecStartPre=/usr/bin/chown iredadmin:iredadmin /var/run/iredadmin +ExecStartPre=/usr/bin/chmod 0755 /var/run/iredadmin +ExecStart=/usr/sbin/uwsgi --ini /opt/www/iredadmin/rc_scripts/uwsgi/rhel7.ini --pidfile /var/run/iredadmin/iredadmin.pid +ExecStop=/usr/sbin/uwsgi --stop /var/run/iredadmin/iredadmin.pid +ExecStopPost=/usr/bin/rm -rf /var/run/iredadmin +KillSignal=SIGTERM +TimeoutStopSec=5 +PrivateTmp=true + +[Install] +WantedBy=multi-user.target diff --git a/rc_scripts/systemd/rhel8.service b/rc_scripts/systemd/rhel8.service new file mode 100644 index 0000000..597fb1e --- /dev/null +++ b/rc_scripts/systemd/rhel8.service @@ -0,0 +1,18 @@ +[Unit] +Description=iRedAdmin daemon service +After=network.target local-fs.target remote-fs.target + +[Service] +Type=simple +ExecStartPre=-/usr/bin/mkdir /var/run/iredadmin +ExecStartPre=/usr/bin/chown iredadmin:iredadmin /var/run/iredadmin +ExecStartPre=/usr/bin/chmod 0755 /var/run/iredadmin +ExecStart=/usr/local/bin/uwsgi --ini /opt/www/iredadmin/rc_scripts/uwsgi/rhel8.ini --pidfile /var/run/iredadmin/iredadmin.pid +ExecStop=/usr/local/bin/uwsgi --stop /var/run/iredadmin/iredadmin.pid +ExecStopPost=/usr/bin/rm -rf /var/run/iredadmin +KillSignal=SIGTERM +TimeoutStopSec=5 +PrivateTmp=true + +[Install] +WantedBy=multi-user.target diff --git a/rc_scripts/systemd/rhel9.service b/rc_scripts/systemd/rhel9.service new file mode 100644 index 0000000..f7cd9b5 --- /dev/null +++ b/rc_scripts/systemd/rhel9.service @@ -0,0 +1,18 @@ +[Unit] +Description=iRedAdmin daemon service +After=network.target local-fs.target remote-fs.target + +[Service] +Type=simple +ExecStartPre=-/usr/bin/mkdir /var/run/iredadmin +ExecStartPre=/usr/bin/chown iredadmin:iredadmin /var/run/iredadmin +ExecStartPre=/usr/bin/chmod 0755 /var/run/iredadmin +ExecStart=/usr/sbin/uwsgi --ini /opt/www/iredadmin/rc_scripts/uwsgi/rhel9.ini --pidfile /var/run/iredadmin/iredadmin.pid +ExecStop=/usr/sbin/uwsgi --stop /var/run/iredadmin/iredadmin.pid +ExecStopPost=/usr/bin/rm -rf /var/run/iredadmin +KillSignal=SIGTERM +TimeoutStopSec=5 +PrivateTmp=true + +[Install] +WantedBy=multi-user.target diff --git a/rc_scripts/uwsgi/debian.ini b/rc_scripts/uwsgi/debian.ini new file mode 100644 index 0000000..a7efaab --- /dev/null +++ b/rc_scripts/uwsgi/debian.ini @@ -0,0 +1,17 @@ +[uwsgi] +plugins = python3,syslog +master = true +vhost = true +enable-threads = true +processes = 5 +buffer-size = 8192 +logger = syslog:iredadmin,local5 +log-format = [%(addr)] %(method) %(uri) %(status) %(size) "%(referer)" + +uwsgi-socket = 127.0.0.1:7791 + +uid = iredadmin +gid = iredadmin + +chdir = /opt/www/iredadmin +wsgi-file = iredadmin.py diff --git a/rc_scripts/uwsgi/freebsd.ini b/rc_scripts/uwsgi/freebsd.ini new file mode 100644 index 0000000..9cc4692 --- /dev/null +++ b/rc_scripts/uwsgi/freebsd.ini @@ -0,0 +1,20 @@ +[uwsgi] +master = true +vhost = true +enable-threads = true +processes = 5 +buffer-size = 8192 +logger = syslog:iredadmin,local5 +log-format = [%(addr)] %(method) %(uri) %(status) %(size) "%(referer)" + +# Log pid of master process +safe-pid = true +pidfile = /var/run/iredadmin/iredadmin.pid + +uwsgi-socket = 127.0.0.1:7791 + +uid = iredadmin +gid = iredadmin + +chdir = /usr/local/www/iredadmin +wsgi-file = iredadmin.py diff --git a/rc_scripts/uwsgi/openbsd.ini b/rc_scripts/uwsgi/openbsd.ini new file mode 100644 index 0000000..303652f --- /dev/null +++ b/rc_scripts/uwsgi/openbsd.ini @@ -0,0 +1,13 @@ +[uwsgi] +master = true +vhost = true +enable-threads = true +processes = 5 +buffer-size = 8192 +logger = syslog:iredadmin,local5 +log-format = [%(addr)] %(method) %(uri) %(status) %(size) "%(referer)" + +uwsgi-socket = 127.0.0.1:7791 + +chdir = /var/www/iredadmin +wsgi-file = iredadmin.py diff --git a/rc_scripts/uwsgi/rhel7.ini b/rc_scripts/uwsgi/rhel7.ini new file mode 100644 index 0000000..3ddfba9 --- /dev/null +++ b/rc_scripts/uwsgi/rhel7.ini @@ -0,0 +1,17 @@ +[uwsgi] +plugins = python36,syslog +master = true +vhost = true +enable-threads = true +processes = 5 +buffer-size = 8192 +logger = syslog:iredadmin,local5 +log-format = [%(addr)] %(method) %(uri) %(status) %(size) "%(referer)" + +uwsgi-socket = 127.0.0.1:7791 + +uid = iredadmin +gid = iredadmin + +chdir = /opt/www/iredadmin +wsgi-file = iredadmin.py diff --git a/rc_scripts/uwsgi/rhel8.ini b/rc_scripts/uwsgi/rhel8.ini new file mode 100644 index 0000000..a7efaab --- /dev/null +++ b/rc_scripts/uwsgi/rhel8.ini @@ -0,0 +1,17 @@ +[uwsgi] +plugins = python3,syslog +master = true +vhost = true +enable-threads = true +processes = 5 +buffer-size = 8192 +logger = syslog:iredadmin,local5 +log-format = [%(addr)] %(method) %(uri) %(status) %(size) "%(referer)" + +uwsgi-socket = 127.0.0.1:7791 + +uid = iredadmin +gid = iredadmin + +chdir = /opt/www/iredadmin +wsgi-file = iredadmin.py diff --git a/rc_scripts/uwsgi/rhel9.ini b/rc_scripts/uwsgi/rhel9.ini new file mode 100644 index 0000000..a7efaab --- /dev/null +++ b/rc_scripts/uwsgi/rhel9.ini @@ -0,0 +1,17 @@ +[uwsgi] +plugins = python3,syslog +master = true +vhost = true +enable-threads = true +processes = 5 +buffer-size = 8192 +logger = syslog:iredadmin,local5 +log-format = [%(addr)] %(method) %(uri) %(status) %(size) "%(referer)" + +uwsgi-socket = 127.0.0.1:7791 + +uid = iredadmin +gid = iredadmin + +chdir = /opt/www/iredadmin +wsgi-file = iredadmin.py diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 0000000..6dc923e --- /dev/null +++ b/tools/README.md @@ -0,0 +1,44 @@ +# Cron Jobs + +* dump_disclaimer.py + + Dump per-domain disclaimer which stored in LDAP or SQL database. + It's safe to execute it manually. + +* cleanup_amavisd_db.py + + Cleanup old records from Amavisd database. It's safe to execute it manually. + +* delete_mailboxes.py + + Delete mailboxes which are scheduled to be removed. The schedule date + was set while you removed the mail account with iRedAdmin(-Pro). + +# Utils + +* upgrade_iredadmin.sh + + Upgrade an old iRedAdmin-Pro or iRedAdmin open source edition to current + release. + +* update_mailbox_quota.py + + Update mailbox quota for one user (specified on command line) or bulk users + (read from a plain text file). + +* notify_quarantined_recipients.py + + Notify local recipients (via email) that they have emails quarantined on + server and not delivered to their mailbox. + +* convert_ini_to_py.sh + + Convert old iRedAdmin-Pro config file (.ini format) to the new one. + +* migrate_cluebringer_wblist_to_amavisd.py + + Migrate Cluebringer white/blacklists to Amavisd database, and, optionally, + delete them in Cluebringer database. + + Note: Don't forget to enable iRedAPD plugin `amavisd_wblist` in + `/opt/iredapd/settings.py`. diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/cleanup_amavisd_db.py b/tools/cleanup_amavisd_db.py new file mode 100644 index 0000000..90cae28 --- /dev/null +++ b/tools/cleanup_amavisd_db.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Remove old records in Amavisd database. + +# USAGE: +# +# 1: Make sure you have correct database settings in iRedAdmin config file +# 'settings.py' for Amavisd. +# +# 2: Make sure you have proper values for below two parameters: +# +# AMAVISD_REMOVE_MAILLOG_IN_DAYS = 7 +# AMAVISD_REMOVE_QUARANTINED_IN_DAYS = 7 +# +# Default values is defined in libs/default_settings.py, you can override +# them in settings.py. WARNING: DO NOT MODIFY libs/default_settings.py. +# +# 3: Test this script in command line directly, make sure no errors in output +# message. +# +# # python cleanup_amavisd_db.py +# +# 4: Setup a daily cron job to execute this script. For example: execute +# it daily at 1:30AM. +# +# 30 1 * * * python /path/to/cleanup_amavisd_db.py >/dev/null +# +# That's all. + +import os +import sys +import time +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from libs import iredutils +from tools import ira_tool_lib + +web.config.debug = ira_tool_lib.debug +logger = ira_tool_lib.logger + +if not (settings.amavisd_enable_logging or settings.amavisd_enable_quarantine): + sys.exit("Amavisd is not enabled. SKIP.") + +backend = settings.backend +logger.info('Backend: %s' % backend) +logger.info('SQL server: %s:%d' % (settings.amavisd_db_host, int(settings.amavisd_db_port))) + +db_settings = iredutils.get_settings_from_db(params=['amavisd_remove_quarantined_in_days', 'amavisd_remove_maillog_in_days']) +keep_quar_days = db_settings['amavisd_remove_quarantined_in_days'] +keep_inout_days = db_settings['amavisd_remove_maillog_in_days'] +query_size_limit = settings.AMAVISD_CLEANUP_QUERY_SIZE_LIMIT + +# SQL records in `quarantine` table reference to `msgs`. +if keep_quar_days > keep_inout_days: + keep_inout_days = keep_quar_days + +conn_amavisd = ira_tool_lib.get_db_conn('amavisd') + +if settings.backend in ['mysql', 'ldap']: + # Querying (SELECT) without locking. Require MySQL 5.0+ and InnoDB. + # + # Since we're dealing with sql records created days ago, no new records + # will be inserted with that date, it's safe to use dirty read. + logger.info('Enable dirty read for querying without locking SQL tables.') + try: + conn_amavisd.query('SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED') + except Exception as e: + logger.error('Cannot enable dirty read: %s' % repr(e)) + + +# Removing records from single table. +def remove_from_one_table(sql_table, index_column, removed_values): + total = len(removed_values) + + # Delete how many records each time + offset = query_size_limit + + if total: + loop_times = total / offset + if total % offset: + loop_times += 1 + + for i in range(int(loop_times)): + removing_values = removed_values[offset * i: offset * (i + 1)] + logger.info( + '\t[-] Deleting records: %d - %d (%s)' % (i * offset, i * offset + len(removing_values), time.ctime())) + conn_amavisd.delete(sql_table, + vars={'ids': removing_values}, + where='%s IN $ids' % index_column) + + +# Delete old quarantined mails from table 'msgs'. It will also +# delete records in table 'quarantine'. +logger.info('Delete quarantined mails which older than %d days' % keep_quar_days) +_now = int(time.time()) +_expire_seconds = _now - (keep_quar_days * 86400) +sql_where = """time_num < %d AND quar_type='Q'""" % _expire_seconds + +counter_msgs = 0 +while True: + qr = conn_amavisd.select('msgs', + what='mail_id', + where=sql_where, + limit=query_size_limit) + + if qr: + ids = [r.mail_id for r in qr] + _total = len(ids) + + logger.info('\t[-] Deleting records: %d - %d (%s)' % (counter_msgs + 1, counter_msgs + _total, time.ctime())) + + conn_amavisd.delete('msgs', vars={'ids': ids}, where='mail_id IN $ids') + conn_amavisd.delete('msgrcpt', vars={'ids': ids}, where='mail_id IN $ids') + + counter_msgs += len(ids) + else: + break + +logger.info('Delete incoming/outgoing emails which older than %d days' % keep_inout_days) + +_now = int(time.time()) +_expire_seconds = _now - (keep_inout_days * 86400) +sql_where = """time_num < %d AND (quar_type <> 'Q' OR quar_type IS NULL)""" % _expire_seconds + +# We experienced an issue with PostgreSQL, it always return an non-existing +# SQL record, and it causes endless loop. As a hack, we store all removed +# `mail_id` and compare new `mail_id` with this list. +_removed_ids = set() + +counter_msgrcpt = 0 +while True: + qr = conn_amavisd.select('msgs', + what='mail_id', + where=sql_where, + limit=query_size_limit) + + if qr: + ids = [iredutils.bytes2str(r.mail_id) for r in qr] + _total = len(ids) + + _removing_ids = list(set(ids) - set(_removed_ids)) + if not _removing_ids: + break + + logger.info( + '\t[-] Deleting records: %d - %d (%s)' % (counter_msgrcpt + 1, counter_msgrcpt + _total, time.ctime())) + + conn_amavisd.delete('msgs', vars={'ids': _removing_ids}, where='mail_id IN $ids') + conn_amavisd.delete('msgrcpt', vars={'ids': _removing_ids}, where='mail_id IN $ids') + + counter_msgrcpt += _total + _removed_ids.update(ids) + else: + break + +# delete unreferenced records from tables msgrcpt, quarantine and maddr +logger.info('Delete unreferenced records from table `msgrcpt`.') +conn_amavisd.query(''' + DELETE FROM msgrcpt + WHERE NOT EXISTS (SELECT 1 FROM msgs WHERE mail_id=msgrcpt.mail_id) +''') + +# +# Delete unreferenced records from table `quarantine`. +# +logger.info('Delete unreferenced records from table `quarantine`.') +msgs_mail_ids = set() +maddr_ids_in_use = set() +quar_mail_ids = set() + +qr = conn_amavisd.select('msgs', what='mail_id, sid') +for i in qr: + msgs_mail_ids.add(i.mail_id) + maddr_ids_in_use.add(i.sid) + +qr = conn_amavisd.select('quarantine', what='mail_id') +for i in qr: + quar_mail_ids.add(i.mail_id) + +invalid_quar_mail_ids = [i for i in quar_mail_ids if i not in msgs_mail_ids] +remove_from_one_table(sql_table='quarantine', + index_column='mail_id', + removed_values=invalid_quar_mail_ids) + +# +# Delete unreferenced records from table `maddr`. +# +logger.info('Delete unreferenced records from table `maddr`.') + +# Get all maddr.id +maddr_ids = set() +qr = conn_amavisd.select('maddr', what='id') +for i in qr: + maddr_ids.add(i.id) + +qr = conn_amavisd.select('msgrcpt', what='rid') +for i in qr: + maddr_ids_in_use.add(i.rid) + +invalid_maddr_ids = [i for i in maddr_ids if i not in maddr_ids_in_use] +remove_from_one_table(sql_table='maddr', + index_column='id', + removed_values=invalid_maddr_ids) + +# +# Delete unreferenced records from table `mailaddr`. +# +logger.info('Delete unreferenced records from table `mailaddr`.') + +# Get all `mailaddr.id` +mailaddr_ids = set() +qr = conn_amavisd.select('mailaddr', what='id') +for i in qr: + mailaddr_ids.add(i.id) + +# Get all `wblist.sid` and `outbound_wblist.rid` (both refer to `mailaddr.id`) +wblist_ids = set() + +qr = conn_amavisd.select('wblist', what='sid') +for i in qr: + wblist_ids.add(i.sid) + +try: + qr = conn_amavisd.select('outbound_wblist', what='rid') + for i in qr: + wblist_ids.add(i.rid) +except: + # No outbound_wblist table + pass + +invalid_mailaddr_ids = [i for i in mailaddr_ids if i not in wblist_ids] +remove_from_one_table(sql_table='mailaddr', + index_column='id', + removed_values=invalid_mailaddr_ids) + +logger.info('') +logger.info('Remained records:') +logger.info('') +logger.info(' `msgs`: %-7.d' % len(msgs_mail_ids)) +logger.info('`quarantine`: %-7.d' % (len(quar_mail_ids) - len(invalid_quar_mail_ids))) +logger.info(' `maddr`: %-7.d' % (len(maddr_ids) - len(invalid_maddr_ids))) +logger.info(' `mailaddr`: %-7.d' % (len(mailaddr_ids) - len(invalid_mailaddr_ids))) + +if counter_msgs \ + or counter_msgrcpt \ + or invalid_quar_mail_ids \ + or invalid_maddr_ids \ + or invalid_mailaddr_ids: + msg = 'Removed records: ' + msg += '%d in msgs, ' % counter_msgs + msg += '%d in msgrcpt, ' % counter_msgrcpt + msg += '%d in quarantine, ' % len(invalid_quar_mail_ids) + msg += '%d in maddr, ' % len(invalid_maddr_ids) + msg += '%d in mailaddr.' % len(invalid_mailaddr_ids) + + ira_tool_lib.log_to_iredadmin(msg, admin='cleanup_amavisd_db', event='cleanup_db') + logger.info('Log cleanup status.') diff --git a/tools/cleanup_db.py b/tools/cleanup_db.py new file mode 100644 index 0000000..f775b94 --- /dev/null +++ b/tools/cleanup_db.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Remove old records in iRedAdmin SQL database. + +# USAGE: +# +# 1: Make sure you have proper values for below two parameters: +# +# IREDADMIN_LOG_KEPT_DAYS = 30 +# +# Default values is defined in libs/default_settings.py, you can override +# them in settings.py. WARNING: DO NOT MODIFY libs/default_settings.py. +# +# 2: Test this script in command line directly, make sure no errors in output +# message. +# +# # python cleanup_db.py +# +# 3: Setup a daily cron job to execute this script. For example: execute +# it daily at 1:30AM. +# +# 30 1 * * * python /path/to/cleanup_db.py >/dev/null +# +# That's all. + +import os +import sys +import time +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from tools.ira_tool_lib import debug, logger, sql_dbn, get_db_conn, sql_count_id + +web.config.debug = debug + +backend = settings.backend +logger.info('Backend: %s' % backend) +logger.info('SQL server: %s:%d' % (settings.iredadmin_db_host, int(settings.iredadmin_db_port))) + +query_size_limit = 100 + +conn_iredadmin = get_db_conn('iredadmin') + +# +# iredadmin.log +# +_days = settings.IREDADMIN_LOG_KEPT_DAYS +logger.info('Delete old admin activity log (> %d days)' % _days) + +if sql_dbn == 'mysql': + sql_where = """timestamp < DATE_SUB(NOW(), INTERVAL %d DAY)""" % _days +elif sql_dbn == 'postgres': + sql_where = """timestamp < CURRENT_TIMESTAMP - INTERVAL '%d DAYS'""" % _days +else: + logger.error('Invalid SQL backend: %s' % sql_dbn) + sys.exit() + +total_before = sql_count_id(conn_iredadmin, 'log') +conn_iredadmin.delete('log', where=sql_where) +total_after = sql_count_id(conn_iredadmin, 'log') +logger.info('\t- %d removed, %d left.' % (total_before - total_after, total_after)) + +# +# iredadmin.domain_ownership +# +_days = settings.DOMAIN_OWNERSHIP_EXPIRE_DAYS +logger.info('Delete old domain ownership verification records (> %d days)' % _days) + +total_before = sql_count_id(conn_iredadmin, 'domain_ownership') +conn_iredadmin.delete('domain_ownership', where="expire > %d" % (_days * 24 * 60 * 60)) +total_after = sql_count_id(conn_iredadmin, 'domain_ownership') +logger.info('\t- %d removed, %d left.' % (total_before - total_after, total_after)) + +# +# iredadmin.newsletter_subunsub_confirms +# +now = int(time.time()) +_hours = settings.NEWSLETTER_SUBSCRIPTION_REQUEST_KEEP_HOURS +logger.info('Delete expired newsletter subscription confirm tokens (> %d hours)' % _hours) + +total_before = sql_count_id(conn_iredadmin, 'newsletter_subunsub_confirms') +_expired = now - (_hours * 60 * 60) +conn_iredadmin.delete('newsletter_subunsub_confirms', where="expired <= %d" % _expired) +total_after = sql_count_id(conn_iredadmin, 'newsletter_subunsub_confirms') +logger.info('\t- %d removed, %d left.' % (total_before - total_after, total_after)) diff --git a/tools/delete_mailboxes.py b/tools/delete_mailboxes.py new file mode 100644 index 0000000..ee961db --- /dev/null +++ b/tools/delete_mailboxes.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Delete mailboxes which are scheduled to be removed. +# +# Notes: iRedAdmin will store maildir path of removed mail users in SQL table +# `iredadmin.deleted_mailboxes` (LDAP backends) or +# `vmail.deleted_mailboxes` (SQL backends). +# +# Usage: Either run this script manually, or run it with a daily cron job. +# +# # python3 delete_mailboxes.py +# +# Available arguments: +# +# * --delete-without-timestamp: +# +# [RISKY] If no timestamp string in maildir path, continue to delete it. +# +# With default iRedMail settings, maildir path will contain a timestamp +# like this: /u/s/e/username-2016.08.17.09.53.03/ +# (2016.08.17.09.53.03 is the timestamp), this way all created maildir +# paths are unique, even if you removed the user and recreate it with +# same mail address. +# +# Without timestamp in maildir path (e.g. /u/s/e/username/), +# if you removed a user and recreate it someday, this user will see old +# emails in old mailbox (because maildir path is same as old user's). So +# it becomes RISKY to remove the mailbox if no timestamp in maildir path. +# +# * --delete-null-date: +# +# Delete mailbox if SQL column `deleted_mailboxes.delete_date` is null. +# +# * --debug: print additional log + +import os +import sys +import time +import logging +import shutil +import pwd +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +from libs import iredutils +from tools import ira_tool_lib +import settings + +web.config.debug = ira_tool_lib.debug +logger = ira_tool_lib.logger + +if '--debug' in sys.argv: + logger.setLevel(logging.DEBUG) + +# Delete if `deleted_mailboxes.delete_date` is null. +delete_null_date = False +if '--delete-null-date' in sys.argv: + delete_null_date = True + +# Make sure there's a timestamp (yyyy.mm.dd.hh.mm.ss) in maildir path, +# otherwise it's too risky to remove this mailbox -- because the maildir +# could be reused by another user after old account was removed. +# +# - Safe to remove: /u/s/e/username-/ +# - Dangerous to remove: /u/s/e/username/ +delete_without_timestamp = False +if '--delete-without-timestamp' in sys.argv: + delete_without_timestamp = True + + +def delete_record(conn_deleted_mailboxes, rid): + try: + conn_deleted_mailboxes.delete('deleted_mailboxes', + vars={'id': rid}, + where='id=$id') + + return True, + except Exception as e: + return False, repr(e) + + +def delete_mailbox(conn_deleted_mailboxes, + record, + all_maildirs=None): + rid = record.id + username = str(record.username).lower() + timestamp = str(record.timestamp) + delete_date = record.delete_date + + maildir = record.maildir + maildir = maildir.replace('//', '/') # Remove duplicate '/' + + if delete_without_timestamp: + # Make sure no other mailbox is stored under the maildir. + if all_maildirs: + if not maildir.endswith('/'): + maildir += '/' + + for mdir in all_maildirs: + if mdir.startswith(maildir) or (mdir == maildir): + logger.error("<<< ABORT, CRITICAL >>> Trying to remove mailbox ({}) owned by user ({}), but there is another mailbox ({}) stored under this directory. Aborted.".format(maildir, username, mdir)) + return False + else: + _dir = maildir.rstrip('/') + + if len(_dir) <= 21: + # Why 21 chars: + # - 20 chars: "-". e.g. "-2014.03.26.15.07.25" + # - username contains at least 1 char + logger.error("<<< SKIP >>> Seems no timestamp in maildir path (%s), too risky to remove this mailbox." % maildir) + return False + + try: + # Extract timestamp string, make sure it's a valid time format. + ts = _dir[-19:] + time.strptime(ts, '%Y.%m.%d.%H.%M.%S') + except Exception as e: + logger.debug("<<< WARNING >>> Invalid or missing timestamp in maildir path (%s), skip." % maildir) + logger.debug("<<< WARNING >>> Error message: %s." % repr(e)) + return False + + # check maildir path + if os.path.isdir(maildir): + # Make sure directory is owned by vmail:vmail + _dir_stat = os.stat(maildir) + _dir_uid = _dir_stat.st_uid + + # Get uid/gid of vmail user + owner = pwd.getpwuid(_dir_uid).pw_name + if owner != 'vmail': + logger.error('<<< ERROR >> Directory is not owned by `vmail` user: uid -> {}, user -> {}.'.format(_dir_uid, owner)) + return False + + try: + msg = '[{}] {}.'.format(username, maildir) + msg += ' Account was deleted at {}.'.format(timestamp) + if delete_date: + msg += ' Mailbox was scheduled to be removed on {}.'.format(delete_date) + else: + msg += ' Mailbox was scheduled to be removed as soon as possible.' + + logger.info(msg) + + logger.info("Removing mailbox: {}".format(maildir)) + # Delete mailbox + shutil.rmtree(maildir) + + # Log this deletion. + ira_tool_lib.log_to_iredadmin(msg, + admin='cron_delete_mailboxes', + username=username, + event='delete_mailboxes') + except Exception as e: + logger.error('<<< ERROR >> while deleting mailbox ({} -> {}): {}'.format(username, maildir, repr(e))) + + # Delete record. + delete_record(conn_deleted_mailboxes=conn_deleted_mailboxes, rid=rid) + + +# Establish SQL connection. +try: + if settings.backend == 'ldap': + conn_deleted_mailboxes = ira_tool_lib.get_db_conn('iredadmin') + + from libs.ldaplib.core import LDAPWrap + _wrap = LDAPWrap() + conn_vmail = _wrap.conn + else: + conn_deleted_mailboxes = ira_tool_lib.get_db_conn('vmail') + conn_vmail = conn_deleted_mailboxes +except Exception as e: + sys.exit('<<< ERROR >>> Cannot connect to SQL database, aborted. Error: %s' % repr(e)) + +# Get paths of all maildirs. +sql_where = 'delete_date <= %s' % web.sqlquote(web.sqlliteral('NOW()')) +if delete_null_date: + sql_where = '(delete_date <= %s) OR (delete_date IS NULL)' % web.sqlquote(web.sqlliteral('NOW()')) + +qr_mailboxes = conn_deleted_mailboxes.select('deleted_mailboxes', where=sql_where) +if not qr_mailboxes: + logger.debug('No mailbox is scheduled to be removed.') + + if not delete_null_date: + logger.debug("To remove mailboxes without schedule date, please run this script with argument '--delete-null-date'.") + + if not delete_without_timestamp: + logger.debug("To remove mailboxes without timesamp in maildir path, please run this script with argument '--delete-without-timestamp'. [WARNING] It's RISKY.") + + sys.exit() + +# Get all maildir paths used by active mail users. +# +# To delete mailbox without timestamp in maildir path, we must make sure: +# - maildir is not used by some active user +# - no other mailbox is stored under this maildir path +# +# Q: Why query all maildir paths instead of querying SQL/LDAP directly? +# A: +# 1. LDAP attribute `homeDirectory` doesn't support `sub` (substring) index. +# 2. if maildir path contains duplicate '/', the validation will fail (not +# equal). +all_maildirs = [] +if delete_without_timestamp: + if settings.backend == 'ldap': + _qr = conn_vmail.search_s(settings.ldap_basedn, + 2, # ldap.SCOPE_SUBTREE + "(objectClass=mailUser)", + ['homeDirectory']) + for (_dn, _ldif) in _qr: + _ldif = iredutils.bytes2str(_ldif) + if 'homeDirectory' in _ldif: + _dir = _ldif['homeDirectory'][0].lower().replace('//', '/') + all_maildirs.append(_dir) + elif settings.backend in ['mysql', 'pgsql']: + # WARNING: always append '/' in returned maildir path. + _qr = conn_vmail.select('mailbox', + what="LOWER(CONCAT(storagebasedirectory, '/', storagenode, '/', maildir, '/')) AS maildir") + + all_maildirs = [str(i.maildir).replace('//', '/') for i in _qr] + +for r in list(qr_mailboxes): + delete_mailbox(conn_deleted_mailboxes=conn_deleted_mailboxes, + record=r, + all_maildirs=all_maildirs) diff --git a/tools/delete_sessions.py b/tools/delete_sessions.py new file mode 100644 index 0000000..84ed1f6 --- /dev/null +++ b/tools/delete_sessions.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Delete all records in SQL table "iredadmin.sessions" to force +# all admins to re-login. + +import os +import sys +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +from tools import ira_tool_lib + +web.config.debug = ira_tool_lib.debug +logger = ira_tool_lib.logger + +conn = ira_tool_lib.get_db_conn('iredadmin') + +logger.info('Delete all existing sessions to force all admins to re-login.') +conn.query('DELETE FROM sessions') diff --git a/tools/dump_disclaimer.py b/tools/dump_disclaimer.py new file mode 100644 index 0000000..cc6c96a --- /dev/null +++ b/tools/dump_disclaimer.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Updated: 2012.07.01 +# Purpose: Dump disclaimer text from OpenLDAP directory server or SQL servers. +# Requirements: iRedMail-0.5.0 or later releases +# +# Shipped within iRedAdmin-Pro: http://www.iredmail.org/admin_panel.html + +# USAGE: +# +# - Make sure you have correct backend related settings in iRedAdmin config +# file, settings.ini. +# +# - Test this script in command line directly, make sure no errors in output +# message. +# +# # python /path/to/dump_disclaimer.py /etc/postfix/disclaimer/ +# +# - Setup a cron job to execute this script daily. For example: execute +# this script at 2:01AM every day. +# +# 1 2 * * * python /path/to/dump_disclaimer.py /etc/postfix/disclaimer/ +# +# That's all. + +import os +import sys +import web + +web.config.debug = False + +# Directory used to store disclaimer files. +# Default directory is /etc/postfix/disclaimer/. +# Default disclaimer file name is [domain_name].txt +if len(sys.argv) != 2: + sys.exit('Error: Please specify a directory used to store disclaimer, default is /etc/postfix/disclaimer/') +else: + DISCLAIMER_DIR = sys.argv[1] + DISCLAIMER_FILE_EXT = '.txt' + +os.environ['LC_ALL'] = 'C' +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from libs import iredutils +from tools import ira_tool_lib +logger = ira_tool_lib.logger + +if settings.backend == 'ldap': + import ldap +elif settings.backend == 'mysql': + sql_dbn = 'mysql' +elif settings.backend == 'pgsql': + sql_dbn = 'postgres' + + +def write_disclaimer(text, filename, file_type='txt'): + # Write plain text + try: + f = open(filename, 'w') + + if file_type == 'html': + html = """

----------


""" + html += """

""" + text + """

""" + + f.write('\n' + html + '\n') + else: + f.write('\n---------\n' + text + '\n') + logger.info(" + %s" % filename) + f.close() + except Exception as e: + logger.info('<<< ERROR >>> %s' % str(e)) + + +def handle_disclaimer(domain, disclaimer_text): + """Dump or remove disclaimer text.""" + txt = os.path.join(DISCLAIMER_DIR, domain + '.txt') + html = os.path.join(DISCLAIMER_DIR, domain + '.html') + + if disclaimer_text: + write_disclaimer(text=disclaimer_text, + filename=txt, + file_type='txt') + + write_disclaimer(text=disclaimer_text, + filename=html, + file_type='html') + else: + # Remove old disclaimer file if no disclaimer setting + try: + for f in [txt, html]: + if os.path.isfile(f): + os.remove(f) + logger.info(" - Remove %s." % f) + except OSError: + pass + except Exception as e: + # Other errors. + logger.info("<<< ERROR >>> {}: {}.".format(domain, str(e))) + + +def dump_from_ldap(): + """Dump disclaimer text from LDAP server.""" + logger.info('Connecting to LDAP server') + conn = ldap.initialize(uri=settings.ldap_uri, + trace_level=0, + bytes_strictness='silent') + conn.set_option(ldap.OPT_PROTOCOL_VERSION, 3) + + logger.info('Binding with dn: %s' % settings.ldap_basedn) + conn.bind_s(settings.ldap_bind_dn, settings.ldap_bind_password) + + # Search and get disclaimer. + logger.info('Searching all domains') + qr = conn.search_s( + settings.ldap_basedn, + ldap.SCOPE_ONELEVEL, + '(objectClass=mailDomain)', + ['domainName', 'domainAliasName', 'disclaimer'], + ) + + logger.info('Dumping ...') + + for (_dn, _ldif) in qr: + _ldif = iredutils.bytes2str(_ldif) + + # Get domain names. + _domains = _ldif['domainName'] + _alias_domains = _ldif.get('domainAliasName', []) + disclaimer_text = _ldif.get('disclaimer', [''])[0] + + domains = _domains + _alias_domains + + for domain in domains: + handle_disclaimer(domain, disclaimer_text) + + conn.unbind() + logger.info('Connection closed.') + + +def dump_from_sql(): + """Dump disclaimer text from MySQL or PostgreSQL server.""" + logger.info("Connecting to SQL server '%s:%d' as user '%s' ..." % (settings.vmail_db_host, + int(settings.vmail_db_port), + settings.vmail_db_user)) + + conn = web.database(dbn=sql_dbn, + host=settings.vmail_db_host, + port=int(settings.vmail_db_port), + db=settings.vmail_db_name, + user=settings.vmail_db_user, + pw=settings.vmail_db_password) + + logger.info('Get all alias domains') + qr = conn.select('alias_domain', what='alias_domain, target_domain') + alias_domains = {} + for i in qr: + _alias_domain = str(i.alias_domain).lower() + _target_domain = str(i.target_domain).lower() + + if _target_domain in alias_domains: + alias_domains[_target_domain].append(_alias_domain) + else: + alias_domains[_target_domain] = [_alias_domain] + + # Search and get disclaimer. + logger.info('Get all primary domains') + qr = conn.select('domain', what='domain, disclaimer') + + # Dump disclaimer for every domain. + logger.info('Dumping...') + for r in qr: + domain = str(r.domain).lower() + disclaimer_text = r.disclaimer + + domains = [domain] + alias_domains.get(domain, []) + + logger.info(domain) + for domain in domains: + handle_disclaimer(domain, disclaimer_text) + + logger.info('Completed.') + + +if settings.backend == 'ldap': + dump_from_ldap() +elif settings.backend in ['mysql', 'pgsql']: + dump_from_sql() diff --git a/tools/dump_quarantined_mails.py b/tools/dump_quarantined_mails.py new file mode 100644 index 0000000..2946852 --- /dev/null +++ b/tools/dump_quarantined_mails.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Author: Zhang Huangbin +# Purpose: Dump quarantined emails to given directory (specified on command line). +# +# Usage: +# +# python dump_quarantined_mail.py /path/to/dir + +import os +import sys +import time +import web + +output_dir = sys.argv[1] +if not os.path.isdir(output_dir): + sys.exit("Output directory doesn't exist: %s" % output_dir) + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +from tools.ira_tool_lib import debug, get_db_conn + +web.config.debug = debug + +now = int(time.time()) +conn_amavisd = get_db_conn('amavisd') +conn_iredadmin = get_db_conn('iredadmin') + +# Get last time +last_time = 0 +try: + qr = conn_iredadmin.select('tracking', what='v', where="k='dump_quarantined_mail'", limit=1) + if qr: + last_time = int(qr[0].v) +except: + pass + +# Get value of all `quarantine.mail_id`. +try: + qr = conn_amavisd.select(['msgs', 'quarantine'], + what='msgs.mail_id AS mail_id', + where='msgs.mail_id=quarantine.mail_id AND msgs.time_num >= %d' % last_time, + group='msgs.mail_id') +except Exception as e: + print('<<< ERROR >>> {}'.format(repr(e))) + sys.exit() + +total = len(qr) +print("* Found {} quarantined emails in SQL db.".format(total)) + +counter = 1 +for r in qr: + mail_id = str(r.mail_id) + try: + records = conn_amavisd.select('quarantine', + what='mail_text', + where='mail_id = %s' % web.sqlquote(mail_id), + order='chunk_ind ASC') + + if not records: + continue + + # Combine mail_text as RAW mail message. + message = '' + for i in list(records): + for j in i.mail_text: + message += j + + # Write message to file + try: + eml_path = os.path.join(output_dir, 'spam-' + mail_id) + print("[{}/{}] Dumping email to file: {}".format(counter, total, eml_path)) + + f = open(eml_path, 'w') + f.write(message) + f.close() + except Exception as e: + print('<<< ERROR >>> cannot write file {}'.format(repr(e))) + except Exception as e: + print("<<< ERROR >>> {}".format(repr(e))) + + counter += 1 + +# Log last time. +conn_iredadmin.delete('tracking', where="k='dump_quarantined_mail'") +conn_iredadmin.insert('tracking', k='dump_quarantined_mail', v=now) diff --git a/tools/export_last_login.py b/tools/export_last_login.py new file mode 100644 index 0000000..e669087 --- /dev/null +++ b/tools/export_last_login.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +""" +Query user last login info from (My)SQL database and display it in a more +readable format (plain text or html). + +Note: You need to follow this tutorial to enable last_login plugin in Dovecot: + https://docs.iredmail.org/track.user.last.login.html + +Usage: + + python3 export_last_login.py # in plain text format + python3 export_last_login.py html > export.html # in html format +""" +import os +import sys +import time +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from tools import ira_tool_lib +from libs.iredutils import epoch_seconds_to_gmt + +web.config.debug = ira_tool_lib.debug +logger = ira_tool_lib.logger + +if settings.backend == 'ldap': + conn = ira_tool_lib.get_db_conn('iredadmin') +else: + conn = ira_tool_lib.get_db_conn('vmail') + +# Get output format +try: + export_format = sys.argv[1] +except: + export_format = 'text' + +try: + qr = conn.select('last_login', + order='last_login DESC') +except Exception as e: + sys.exit("Query failed: {}".format(e)) + +if export_format == 'html': + _now = time.strftime('%Y-%d-%m %H:%M:%S') + + html = """ + + + + + + + +

User Last Login Time ({0})

+ + + + + + + + + + """.format(_now) + +counter = 1 +for row in qr: + username = row.username + seconds = row.last_login + last_login = epoch_seconds_to_gmt(seconds) + + if export_format == 'html': + html += """ + + + + + + """.format(counter, username, last_login) + else: + print("{:6} | {:30} | {}".format(counter, username, last_login)) + + counter += 1 + +if export_format == 'html': + html += """
#EmailTime (GMT)
{}{}{}
""" + print(html) diff --git a/tools/import_users.py b/tools/import_users.py new file mode 100644 index 0000000..1c1979f --- /dev/null +++ b/tools/import_users.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# Purpose: Read mail accounts from given plain text file (in specified format), +# then create them with iRedAdmin-Pro RESTful API interface. +# +# Usage: +# +# - Make sure your iRedAdmin-Pro has RESTful API interface enabled by +# following our tutorial: +# https://docs.iredmail.org/iredadmin-pro.restful.api.html#enable-restful-api +# +# - Generate file /opt/users.list which contains the mail accounts you want +# to import, one account per line, with account info stored in few fields: +# +# 1: [REQUIRED] user's full email address. +# 2: [REQUIRED] plain text or password hash which starts with the password +# scheme name. For example, "{SSHA}xxx", "{SSHA512}xxx". +# 3: [optional] mailbox quota in MB. Must be an integer number. +# 4: [optional] full display name. +# 5: [optional] list of mailing list addresses. If not empty, user will be +# assigned to given mailing lists as a member. +# +# Notes: +# +# - Multiple addresses must be separated by ":". +# - If mailing list doesn't exist, it will not be created automatically. +# 6: [optional] employeeid: employee id. +# +# NOTE: the separator "," for ending EMPTY optional fields is not required. +# +# Samples: +# +# user@domain.com, plain_password +# user@domain.com, plain_password, 1024, Zhang Huangbin, list1@domain.com:list2@domain.com +# user@domain.com, plain_password, , , list1@domain.com:list2@domain.com +# user@domain.com, plain_password, 1024, Zhang Huangbin +# +# - Update 3 parameters in this file: +# +# api_endpoint = '' +# verify_cert = True +# admin = 'postmaster@a.io' +# pw = 'password' +# +# - "api_endpoint" is the endpoint of iRedAdmin-Pro RESTful API. +# - With "verify_cert = True", a valid ssl cert is required on API +# server (https://). If you don't have a valid ssl cert yet, please set +# it to False. +# - "admin" is the email address of domain admin which has privilege to +# manage the email domain which you're going to import users to. +# - "pw" is plain password of domain admin. +# +# - Run commands below to create users listed in the "/opt/users.list" file: +# +# python import_users.py /opt/users.list + +import os +import sys +import requests +from requests.packages.urllib3.exceptions import InsecureRequestWarning +requests.packages.urllib3.disable_warnings(InsecureRequestWarning) + +# Endpoint of iRedAdmin-Pro RESTful API +api_endpoint = 'http://127.0.0.1:8080/api' + +# Verify SSL cert of API server. +# If you don't have a valid SSL cert yet, please set it to False. +verify_cert = True + +# Domain admin email address and password +admin = 'postmaster@a.io' +pw = 'www' + +# Define the order of fields in each line. Fields must be separated by comma. +# +# +# WARNING: For empty optional fields, a comma is still required as placeholder. +# +# Samples: +# +# user@domain.com, plain_password, , , +# user@domain.com, plain_password, 1024, Zhang Huangbin, list1@domain.com:list2@domain.com, +# user@domain.com, plain_password, , , list1@domain.com:list2@domain.com, +# user@domain.com, plain_password, 1024, Zhang Huangbin,, +# +field_map = ['mail', 'password', 'quota', 'name', 'groups', 'employeeid'] + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) +from libs import iredutils + + +def __get(url, data=None): + _url = api_endpoint + url + r = requests.get(_url, data=data, cookies=cookies, verify=verify_cert) + return r.json() + + +def __post(url, data=None): + _url = api_endpoint + url + r = requests.post(_url, data=data, cookies=cookies, verify=verify_cert) + return r.json() + + +def __put(url, data=None): + _url = api_endpoint + url + r = requests.put(_url, data=data, cookies=cookies, verify=verify_cert) + return r.json() + + +def __delete(url, data=None): + _url = api_endpoint + url + r = requests.delete(_url, data=data, cookies=cookies, verify=verify_cert) + return r.json() + + +def usage(): + pass + + +if len(sys.argv) != 2 or len(sys.argv) > 2: + print("Usage: $ python bulk_import.py /path/to/file") + usage() + sys.exit() +else: + file = sys.argv[1] + if not os.path.exists(file): + print("<<< ERROR >>> file does not exist: {}".format(file)) + sys.exit() + +# +# Login +# +r = requests.post(api_endpoint + '/login', + data={'username': admin, 'password': pw}, + verify=verify_cert) + +# Get returned JSON data +res = r.json() +if not res['_success']: + sys.exit('Login failed') + +cookies = r.cookies + +# Read user list. +f = open(file, 'rb') + +for line in f.readlines(): + line = iredutils.bytes2str(line.strip()) + fields = line.split(',') + + try: + d = {} + for (k, v) in zip(field_map, fields): + d[k] = v + except: + sys.exit("<<< ERROR >>> line has invalid format:\n{}".format(line)) + + # Get user mail address + mail = d.pop('mail') + mail.lower() + if not iredutils.is_email(mail): + sys.exit("<<< ERROR >>> line has invalid user email address: {}\nLine: {}".format(mail, line)) + + password = d.pop('password') + name = d.pop("name", mail.split("@", 1)[0]) + quota = d.pop("quota", "0") + + # Get mail address(es) of assigned mailing list(s) + groups = d.pop('groups', "") + groups.lower() + groups = [addr.lower().strip() for addr in groups.split(':') if iredutils.is_email(addr)] + + # Create user + res = __post('/user/' + mail, + data={'name': name, + 'password': password.strip(), + 'quota': quota}) + + if res['_success']: + print("[OK] Created user: {}".format(mail)) + else: + if res['_msg'] == 'ALREADY_EXISTS': + print("[SKIP] Account already exists: {}.".format(mail)) + continue + else: + sys.exit('<<< ERROR >>> failed to create user: {}'.format(res)) + + if password.startswith('{'): + res = __put('/user/' + mail, + data={'password_hash': password}) + + if res['_success']: + print(" |- [OK] Updated user password (hash): {}".format(mail)) + else: + sys.exit('<<< ERROR >>> failed to updated user password (hash): {}, error: {}'.format(mail, res)) + + if groups: + for group in groups: + res = __put('/ml/' + group, + data={'add_subscribers': mail, + 'require_confirm': 'no'}) + + if res['_success']: + print(" |- [OK] Subscribed user to mailing list: {} -> {}".format(mail, group)) + else: + print('<<< WARNING >>> failed to subscribe user to mailing list: {} -> {}, error: {}'.format(mail, group, res)) + + employeeid = d.pop("employeeid", "") + if employeeid: + res = __put('/user/' + mail, + data={'employeeid': employeeid}) + + if res['_success']: + print(" |- [OK] Updated employeeid: {}".format(mail)) + else: + sys.exit('<<< ERROR >>> failed to updated employeeid: {}, error: {}'.format(mail, res)) diff --git a/tools/ira_tool_lib.py b/tools/ira_tool_lib.py new file mode 100644 index 0000000..2d0cd7b --- /dev/null +++ b/tools/ira_tool_lib.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# Author: Zhang Huangbin +# Purpose: Library used by other scripts under tools/ directory. + +import os +import sys +import logging +import web + +debug = False + +# Set True to print SQL queries. +web.config.debug = debug + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from libs import iredutils + +backend = settings.backend +if backend in ['ldap', 'mysql']: + sql_dbn = 'mysql' +elif backend in ['pgsql']: + sql_dbn = 'postgres' +else: + sys.exit('Error: Unsupported backend (%s).' % backend) + +# logging +logger = logging.getLogger('iredadmin') +_ch = logging.StreamHandler(sys.stdout) +_formatter = logging.Formatter('* %(message)s') +_ch.setFormatter(_formatter) +logger.addHandler(_ch) +logger.setLevel(logging.INFO) + + +def get_db_conn(db_name): + if backend == 'ldap' and db_name in ['ldap', 'vmail']: + logger.error("""Please use code below to get LDAP connection cursor:\n + +from libs.ldaplib.core import LDAPWrap\n +_wrap = LDAPWrap()\n +conn = _wrap.conn\n""") + + return None + + try: + conn = web.database( + dbn=sql_dbn, + host=settings.__dict__[db_name + '_db_host'], + port=int(settings.__dict__[db_name + '_db_port']), + db=settings.__dict__[db_name + '_db_name'], + user=settings.__dict__[db_name + '_db_user'], + pw=settings.__dict__[db_name + '_db_password'], + ) + + conn.supports_multiple_insert = True + return conn + except Exception as e: + logger.error(e) + return None + + +# Log in `iredadmin.log` +def log_to_iredadmin(msg, event, admin='', username='', loglevel='info'): + conn = get_db_conn('iredadmin') + + try: + conn.insert('log', + admin=admin, + username=username, + event=event, + loglevel=loglevel, + msg=str(msg), + ip='127.0.0.1', + timestamp=iredutils.get_gmttime()) + except: + pass + + return None + + +def sql_count_id(conn, table, column='id', where=None): + if where: + qr = conn.select(table, + what='count(%s) as total' % column, + where=where) + else: + qr = conn.select(table, + what='count(%s) as total' % column) + if qr: + total = qr[0].total + else: + total = 0 + + return total diff --git a/tools/migrate_cluebringer_wblist_to_amavisd.py b/tools/migrate_cluebringer_wblist_to_amavisd.py new file mode 100644 index 0000000..1dba3a6 --- /dev/null +++ b/tools/migrate_cluebringer_wblist_to_amavisd.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Migrate Cluebringer white/blacklist to Amavisd database. +# +# Note: it's safe to execute this script as many times as you want, it won't +# generate duplicate records. + +import os +import sys +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from libs.iredutils import is_valid_amavisd_address +from libs.amavisd import wblist +from tools import ira_tool_lib + +web.config.debug = ira_tool_lib.debug +logger = ira_tool_lib.logger + +# Check database name to make sure it's Cluebringer +if settings.policyd_db_name != 'cluebringer': + sys.exit('Error: not a Cluebringer database.') + +logger.info('Establish SQL connection.') +conn = ira_tool_lib.get_db_conn('policyd') + +logger.info('Query white/blacklist info.') + +# Converted wblist +wl = [] +bl = [] + +# value of sql column: policy_groups.id +wl_id = None +bl_id = None +wb_ids = [] + +# query whitelist and/or blacklist. possible values: 'wl', 'bl'. +query_lists = [] + +# get policy_groups.id +qr = conn.select('policy_groups', what='id,name', where="name IN ('whitelists', 'blacklists')") +if qr: + for r in qr: + if r.name == 'whitelists': + wl_id = r.id + elif r.name == 'blacklists': + bl_id = r.id + + if wl_id: + logger.info('policy_groups.id: %d -> whitelists' % wl_id) + query_lists.append('wl') + wb_ids.append(wl_id) + + if bl_id: + logger.info('policy_groups.id: %d -> blacklists' % bl_id) + query_lists.append('bl') + wb_ids.append(bl_id) +else: + logger.info('No whitelist/blacklist found. Exit.') + sys.exit() + +logger.info('Query all whitelists and blacklists.') +qr = conn.select('policy_group_members', + vars={'wb_ids': wb_ids}, + what='policygroupid, member', + where='policygroupid IN $wb_ids AND disabled=0') + +if qr: + logger.info('Convert Cluebringer white/blacklists to Amavisd syntax format.') + for r in qr: + # Single IP Address: 192.168.2.10 + # CIDR formatted range of IP addresses: 192.168.2.10/31 + # Single user: user@example.com + # Entire domain: @example.com + # All sub-domains: .example.com + value = None + if is_valid_amavisd_address(r.member): + value = r.member + else: + # Convert from different syntax format + if r.member.startswith('.'): + tmp = '@' + r.member + if is_valid_amavisd_address(tmp): + value = tmp + else: + logger.info('[?] Discard record in improper format: %s, cannot convert.' % r.member) + elif '/' in r.member: + logger.info('[?] Discard record in improper format: %s. CIDR IP range is not supported.' % r.member) + + if value: + if r.policygroupid == wl_id: + wl.append(value) + else: + bl.append(value) + +if wl: + logger.info('Converted whitelisted: %d total' % len(wl)) +else: + logger.info('No whitelists found.') + +if bl: + logger.info('Converted blacklisted: %d total' % len(bl)) +else: + logger.info('No blacklists found.') + +confirm = input('Migrate converted white/blacklists to Amavisd database right now? [y|N]') +if confirm not in ['y', 'Y', 'yes', 'YES']: + logger.info('Exit without migrating to Amavisd database.') + sys.exit() + +# Import to Amavisd database. +try: + logger.info('Migrating, please wait ...') + wblist.add_wblist(account='@.', + wl_senders=wl, + bl_senders=bl, + flush_before_import=False) + + logger.info("Don't forget to enable iRedAPD plugin 'amavisd_wblist' in /opt/iredapd/settings.py.") +except Exception as e: + logger.info(str(e)) + +# Ask to delete wblist in cluebringer +confirm = input('Delete all white/blacklists stored in Cluebringer database? [y|N]') +if confirm not in ['y', 'Y', 'yes', 'YES']: + logger.info('Exit without deleting Cluebringer white/blacklists.') + sys.exit() + +conn.delete('policy_group_members', vars={'wb_ids': wb_ids}, where='policygroupid IN $wb_ids') +conn.delete('policy_groups', vars={'wb_ids': wb_ids}, where='id IN $wb_ids') +conn.delete('policy_members', where="destination='%%internal_domains' AND source IN ('%%whitelists', '%%blacklists')") + +# Get policies.id +qr = conn.select('policies', what='id', where="name IN ('whitelists', 'blacklists')") +if qr: + pids = [r.id for r in qr] + + conn.delete('access_control', vars={'pids': pids}, where='policyid IN $pids') + conn.delete('policies', vars={'pids': pids}, where='id IN $pids') + +logger.info('DONE') diff --git a/tools/notify_quarantined_recipients.html b/tools/notify_quarantined_recipients.html new file mode 100644 index 0000000..9e7c000 --- /dev/null +++ b/tools/notify_quarantined_recipients.html @@ -0,0 +1,31 @@ + + + + + + + + + +

Quarantined mails will be kept for %(quar_keep_days)d days, please login to self-service site to manage them: %(iredadmin_url)s.

+ +

Date and time are in time zone: %(timezone)s.

+ + + + + + + + + + + %(quar_mail_info)s + +
SubjectSenderSpam LevelTime
+ + diff --git a/tools/notify_quarantined_recipients.py b/tools/notify_quarantined_recipients.py new file mode 100644 index 0000000..0664db0 --- /dev/null +++ b/tools/notify_quarantined_recipients.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Notify local recipients (via email) that they have emails +# quarantined on server and not delivered to their mailbox. + +# Usage: +# +# - Set a correct URL in iRedAdmin-Pro config file `settings.py`, so that +# users can manage quarantined email within received notification email: +# +# # URL of your iRedAdmin-Pro login page which will be shown in notification +# # email, so that user can login to manage quarantined emails. +# # Sample: 'https://your_server.com/iredadmin/' +# # +# # Note: mail domain must have self-service enabled, otherwise normal +# # mail user cannot login to iRedAdmin-Pro for self-service. +# NOTIFICATION_URL_SELF_SERVICE = 'https://[your_server]/iredadmin/' +# +# - Setup a cron job to run this script every 6 or 12, 24 hours, it's up to +# you. Sample cron job (every 12 hours): +# +# 1 */12 * * * python /path/to/notify_quarantined_recipients.py >/dev/null +# +# Available arguments: +# +# --force-all: +# Send notification to all users (who have email quarantined). +# +# --force-all-time: +# Notify users for their all quarantined emails instead of just new +# ones since last notification. +# +# --notify-backupmx +# Send notification to all recipients under backup mx domain +# +# - Also, it's ok to run this script manually: +# +# # python notify_quarantined_recipients.py [arg1 arg2 arg3 ...] + +# Customization +# +# - This script sends email via /usr/sbin/sendmail command by default, it +# should work quite well and has better performance. if you still prefer +# to send notification email via smtp, please set proper smtp server and +# account info in iRedAdmin-Pro config file `settings.py`: +# +# NOTIFICATION_SMTP_SERVER = 'localhost' +# NOTIFICATION_SMTP_PORT = 587 +# NOTIFICATION_SMTP_STARTTLS = True +# NOTIFICATION_SMTP_USER = '' +# NOTIFICATION_SMTP_PASSWORD = '' +# +# - To custom mail subject of notification email, please define below +# variable in iRedAdmin-Pro config file `settings.py`: +# +# # Subject of notification email. +# NOTIFICATION_QUARANTINE_MAIL_SUBJECT = '[Attention] You have emails quarantined and not delivered to mailbox' +# +# - To custom HTML template file, please create your own template file with +# correct name in either place: +# +# - `/opt/iredmail/custom/iredadmin/notify_quarantined_recipients.html` +# +# This file is used if your iRedMail server was deployed with the +# iRedMail Easy platform (https://www.iredmail.org/easy.html), easy +# for iRedAdmin-Pro upgrade. +# +# - `tools/notify_quarantined_recipients.html.custom` under iRedAdmin-Pro +# directory. +# +# General use. Note: there's a `.custom` suffix in file name. +# +# If no custom file, `tools/notify_quarantined_recipients.html` will be used. +# +# How it works: +# +# - Mail user login to iRedAdmin-Pro (self-service) and choose to receive +# notification email when there's email quarantined. +# +# - OpenLDAP: user will be assigned `enabledService=quar_notify`. +# - SQL backends: column `mailbox.settings` contains `quar_notify:yes`. +# +# - This script queries SQL/LDAP database to see who are willing to receive +# a notification email. +# +# - This script checks Amavisd database to get info of quarantined mails +# for these users. + +import os +import sys +import time +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart +from email.header import Header +import web + +os.environ['LC_ALL'] = 'C' + +script_dir = os.path.abspath(os.path.dirname(__file__)) +rootdir = script_dir + '/../' +sys.path.insert(0, rootdir) + +now = int(time.time()) + +import settings +from libs import iredutils +from libs.ireddate import utc_to_timezone +from tools import ira_tool_lib + +web.config.debug = ira_tool_lib.debug +logger = ira_tool_lib.logger + +backend = settings.backend + +# Read template HTML file. +custom_easy_tmpl = "/opt/iredmail/custom/iredadmin/notify_quarantined_recipients.html" +custom_tmpl = os.path.join(rootdir, 'tools', 'notify_quarantined_recipients.html.custom') +default_tmpl = os.path.join(rootdir, 'tools', 'notify_quarantined_recipients.html') + +if os.path.isfile(custom_easy_tmpl): + html_tmpl = custom_easy_tmpl +elif os.path.isfile(custom_tmpl): + html_tmpl = custom_tmpl +else: + html_tmpl = default_tmpl + +# Info used in notification email. +mail_subject = settings.NOTIFICATION_QUARANTINE_MAIL_SUBJECT +smtp_user = settings.NOTIFICATION_SMTP_USER +iredadmin_url = settings.NOTIFICATION_URL_SELF_SERVICE + +# Use '--force-all' option to notify all mail users. +force_all_users = '--force-all' in sys.argv or False +force_all_time = '--force-all-time' in sys.argv or False +notify_backupmx = '--notify-backupmx' in sys.argv or False + +# Backup MX domains. +# We may not have any accounts under backup mx domain, so if sys admin chooses +# to notify recipients in backup mx domain, we send the notification also. +backupmx_domains = [] + +# List of target users' email address. +target_users = [] + +# Get list of users (email) who asked to receive notification email. +if settings.backend == 'ldap': + from libs.ldaplib.core import LDAPWrap + _wrap = LDAPWrap() + conn_ldap = _wrap.conn + + # Get users who ask to get a notification email under each domain. + if force_all_users: + q_filter = '(&(objectClass=mailUser)(accountStatus=active))' + else: + q_filter = '(&(objectClass=mailUser)(accountStatus=active)(enabledService=quar_notify))' + + try: + qr = conn_ldap.search_s(settings.ldap_basedn, + 2, # ldap.SCOPE_SUBTREE, + q_filter, + ['mail']) + for (_dn, _ldif) in qr: + _ldif = iredutils.bytes2str(_ldif) + target_users += _ldif.get('mail', []) + except Exception as e: + logger.info('<< ERROR >> Error while querying mail users: %s' % repr(e)) + + if notify_backupmx: + # Query all backup mx domains + q_filter = '(&(objectClass=mailDomain)(accountStatus=active)(domainBackupMX=yes)(mtaTransport=relay:*))' + + try: + qr = conn_ldap.search_s(settings.ldap_basedn, + 1, # ldap.SCOPE_ONELEVEL, + q_filter, + ['domainName', 'domainAliasName']) + for (_dn, _ldif) in qr: + _ldif = iredutils.bytes2str(_ldif) + backupmx_domains += _ldif.get('domainName', []) + backupmx_domains += _ldif.get('domainAliasName', []) + except Exception as e: + logger.info('<< ERROR >> Error while querying backup MX domains: %s' % repr(e)) + +elif settings.backend in ['mysql', 'pgsql']: + conn_vmaildb = ira_tool_lib.get_db_conn('vmail') + + # Get all users who asked to receive notification email. + if force_all_users: + sql_where = 'active=1' + else: + sql_where = 'settings LIKE %s AND active=1' % web.sqlquote('%' + 'quar_notify:' + '%') + + qr = conn_vmaildb.select('mailbox', + what='username', + where=sql_where) + + for r in qr: + target_users.append(r.username) + + if notify_backupmx: + # Get all backup mx domains + qr = conn_vmaildb.select('domain', + what='domain', + where='backupmx=1 AND active=1') + for i in qr: + backupmx_domains += [str(i.domain).lower()] + + if backupmx_domains: + # Get all alias domains + qr = conn_vmaildb.select('alias_domain', + vars={'domains': backupmx_domains}, + what='alias_domain', + where='target_domain IN $domains') + + for i in qr: + backupmx_domains += [str(i.alias_domain).lower()] + +if not (target_users or backupmx_domains): + logger.debug('No user asks to receive notification email. Exit.') + sys.exit() + +mail_body_template = open(html_tmpl).read() + +conn_amavisd = ira_tool_lib.get_db_conn('amavisd') +conn_iredadmin = ira_tool_lib.get_db_conn('iredadmin') + +reversed_backupmx_domains = [] +target_backupmx_users = [] +if backupmx_domains: + for d in backupmx_domains: + rd = d.split('.') + rd.reverse() + rd = '.'.join(rd) + + reversed_backupmx_domains.append(rd) + + qr = conn_amavisd.select('maddr', + vars={'rcpt': reversed_backupmx_domains}, + what='email', + where='domain IN $rcpt') + for i in qr: + _email = iredutils.bytes2str(i.email) + target_backupmx_users.append(_email) + + logger.info('%d backup MX domains (%d users) will receive notification email.' % (len(backupmx_domains), len(target_backupmx_users))) + +logger.debug('%d users are willing to receive notification email.' % len(target_users)) + +target_users += target_backupmx_users + +# Notify users. +for user in target_users: + # Get maddr.id of recipient + qr = conn_amavisd.select('maddr', + vars={'rcpt': user}, + what='id', + where='email=$rcpt', + limit=1) + if qr: + rid = qr[0].id + else: + logger.debug('[SKIP] No log of user: ' + user) + continue + + # Get info of quarantined mails + sql_what = 'msgrcpt.rid AS rid,' \ + + 'msgs.mail_id AS mail_id,' \ + + 'msgs.subject AS subject,' \ + + 'msgs.from_addr AS from_addr,' \ + + 'msgs.spam_level AS spam_level,' \ + + 'msgs.time_num' + + sql_where = """msgrcpt.rid=$rid AND msgs.mail_id=msgrcpt.mail_id AND msgs.quar_type='Q'""" + + last_notify_time = 0 + if not force_all_time: + # Get last time + try: + qr = conn_iredadmin.select('tracking', what='v', where="k='quarantine_notify_time'", limit=1) + if qr: + last_notify_time = int(qr[0].v) or 0 + except: + pass + + if last_notify_time: + sql_where += """ AND msgs.time_num >= %s""" % last_notify_time + + qr = conn_amavisd.select(['msgs', 'msgrcpt'], + vars={'rid': rid}, + what=sql_what, + where=sql_where, + order='msgs.time_num DESC') + + if not qr: + logger.debug('[SKIP] No quarantined emails for %s.' % user) + continue + + total = len(qr) + + # Group messages by date. + info_by_date = {} + + quar_mail_info = '\n' + + # Create a HTML table to present quarantined emails. + for rcd in qr: + # time format: Apr 4, 2015 + dt = iredutils.epoch_seconds_to_gmt(iredutils.bytes2str(rcd.time_num)) + time_with_tz = utc_to_timezone(dt=dt, timezone=settings.LOCAL_TIMEZONE) + try: + time_tuple = time_with_tz.timetuple() + except: + time_tuple = time.strptime(time_with_tz, '%Y-%m-%d %H:%M:%S') + + mail_date = time.strftime('%b %d, %Y', time_tuple) + mail_time = time.strftime('%H:%M:%S', time_tuple) + + info = '' + '\n' + info += '' + iredutils.bytes2str(rcd.subject) + '' + '\n' + info += '' + iredutils.bytes2str(rcd.from_addr) + '' + '\n' + info += '' + iredutils.bytes2str(rcd.spam_level) + '' + '\n' + info += '' + mail_time + '' + '\n' + info += '' + '\n\n' + + if mail_date not in info_by_date: + info_by_date[mail_date] = [] + + info_by_date[mail_date].append(info) + + for _date in sorted(list(info_by_date.keys()), reverse=True): + quar_mail_info += '' + _date + '' + '\n' + for r in info_by_date[_date]: + quar_mail_info += r + + msg = MIMEMultipart('alternative') + + msg['Subject'] = Header(mail_subject % {'total': total}, 'utf-8') + msg['To'] = user + + if settings.NOTIFICATION_SENDER_NAME: + msg['From'] = '{} <{}>'.format(Header(settings.NOTIFICATION_SENDER_NAME, 'utf-8'), smtp_user) + else: + msg['From'] = Header(smtp_user, 'utf-8') + + mail_body = mail_body_template % {'quar_mail_info': quar_mail_info, + 'quar_keep_days': settings.AMAVISD_REMOVE_QUARANTINED_IN_DAYS, + 'iredadmin_url': iredadmin_url, + 'timezone': settings.LOCAL_TIMEZONE} + + # HTML email must contain text and html part with same content, otherwise + # it will be considered as not well-formated email. + body_part_plain = MIMEText(mail_body, 'plain', 'utf-8') + msg.attach(body_part_plain) + + body_part_html = MIMEText(mail_body, 'html', 'utf-8') + msg.attach(body_part_html) + + msg_string = msg.as_string() + + ret = iredutils.sendmail(recipients=user, message_text=msg_string) + if ret[0]: + logger.info('+ %s: %d mails.' % (user, total)) + else: + logger.info('+ << ERROR >> Error while sending notification email to {}: {}'.format(user, ret[1])) + +# Log last notify time. +conn_iredadmin.delete('tracking', where="k='quarantine_notify_time'") +conn_iredadmin.insert('tracking', k='quarantine_notify_time', v=now) diff --git a/tools/promote_user_to_global_admin.py b/tools/promote_user_to_global_admin.py new file mode 100644 index 0000000..99b2550 --- /dev/null +++ b/tools/promote_user_to_global_admin.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Author: Zhang Huangbin +# Purpose: Promote given user to be a global admin. +# FYI https://docs.iredmail.org/promote.user.to.be.global.admin.html +# Usage: +# python3 promote_to_global_admin.py + + +def usage(): + print("""Usage: Run this script with user email address: + + # python3 promote_to_global_admin.py user@domain.com + """) + + +import os +import sys +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from tools.ira_tool_lib import debug, get_db_conn +from libs.iredutils import is_email + +backend = settings.backend +web.config.debug = debug + +# Check arguments +if len(sys.argv) == 2: + email = sys.argv[1] + + if not is_email(email): + usage() + sys.exit() +else: + usage() + sys.exit() + +if backend == 'ldap': + from libs.ldaplib.core import LDAPWrap + from libs.ldaplib import ldaputils + _wrap = LDAPWrap() + conn = _wrap.conn + + dn = ldaputils.rdn_value_to_user_dn(email) + mod_attrs = ldaputils.attr_ldif(attr="enabledService", value="domainadmin", mode="add") + mod_attrs += ldaputils.attr_ldif(attr="domainGlobalAdmin", value="yes", mode="add") + + try: + conn.modify_s(dn, mod_attrs) + print("User {} is now a global admin.".format(email)) + except Exception as e: + print("<<< ERROR >>> {}".format(repr(e))) + +elif backend in ['mysql', 'pgsql']: + conn = get_db_conn('vmail') + try: + conn.update("mailbox", + isadmin=1, + isglobaladmin=1, + where="username='{}'".format(email)) + + conn.insert("domain_admins", + username=email, + domain="ALL") + + print("User {} is now a global admin.".format(email)) + except Exception as e: + print("<<< ERROR >>> {}".format(repr(e))) diff --git a/tools/reset_user_password.py b/tools/reset_user_password.py new file mode 100644 index 0000000..5e83821 --- /dev/null +++ b/tools/reset_user_password.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Author: Zhang Huangbin +# Purpose: Update user password. +# Usage: +# python reset_user_password.py + + +def usage(): + print("""Usage: Run this script with user email address and new plain password: + + # python3 reset_user_password.py user@domain.com 123456 + """) + + +import os +import sys +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from tools.ira_tool_lib import debug, get_db_conn +from libs.iredutils import is_email +from libs.iredpwd import generate_password_hash + +backend = settings.backend +web.config.debug = debug + +# Check arguments +if len(sys.argv) == 3: + email = sys.argv[1] + pw = sys.argv[2] + + if not is_email(email): + usage() + sys.exit() +else: + usage() + sys.exit() + +pw_hash = generate_password_hash(pw) +if backend == 'ldap': + from libs.ldaplib.core import LDAPWrap + from libs.ldaplib import ldaputils + _wrap = LDAPWrap() + conn = _wrap.conn + + dn = ldaputils.rdn_value_to_user_dn(email) + mod_attrs = ldaputils.mod_replace('userPassword', pw_hash) + mod_attrs += ldaputils.mod_replace('shadowLastChange', ldaputils.get_days_of_shadow_last_change()) + + try: + conn.modify_s(dn, mod_attrs) + print("[{}] Password has been reset.".format(email)) + except Exception as e: + print("<<< ERROR >>> {}".format(repr(e))) +elif backend in ['mysql', 'pgsql']: + conn = get_db_conn('vmail') + try: + conn.update('mailbox', + password=pw_hash, + passwordlastchange=web.sqlliteral('NOW()'), + where="username='{}'".format(email)) + + print("[{}] Password has been reset.".format(email)) + except Exception as e: + print("<<< ERROR >>> {}".format(repr(e))) diff --git a/tools/update_mailbox_quota.py b/tools/update_mailbox_quota.py new file mode 100644 index 0000000..0aa51fd --- /dev/null +++ b/tools/update_mailbox_quota.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python + +# Author: Zhang Huangbin +# Purpose: Update mailbox quota for one user or multiple users. +# Note: Mailbox quota size unit is bytes. for example, size `104857600` is 100 MB. + + +def usage(): + print("""Usage: + + 1) Update mailbox quota for one user. + + To simply update one user's quota, run this script with user's email + address and new quota size (in bytes). For example: + + # python3 update_mailbox_quota.py user@domain.com 2048576000 + + 2) Update mailbox quota for multiple users. + + - Create text file "new_quota.txt", each line contains one email address + and the new quota size (in bytes). + + user1@domain.com 20480000 + user2@domain.com 102400000 + user3@domain.com 409600000 + + - Run this script with this file: + + # python3 update_mailbox_quota.py new_quota.txt + """) + + +import os +import sys +import web + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from tools.ira_tool_lib import debug, logger, get_db_conn +from libs.iredutils import is_email + +backend = settings.backend +logger.info('Backend: {}'.format(backend)) + +web.config.debug = debug + +# List of (email, quota) tuples. +users = [] + +# Check arguments +if len(sys.argv) == 2: + # bulk update + text_file = sys.argv[1] + if not os.path.isfile(text_file): + sys.exit('<<< ERROR>>> Not a regular file: %s' % text_file) + + # Get all (email, quota) tuples. + f = open(text_file) + for _line in f.readlines(): + (_email, _quota) = _line.strip().split(' ', 1) + if is_email(_email) and _quota.isdigit(): + users += [(_email, _quota)] + else: + print("[SKIP] no valid email address or quota: {}".format(_line)) + +elif len(sys.argv) == 3: + # update single user + _email = sys.argv[1] + _quota = sys.argv[2] + + if is_email(_email): + users += [(_email, _quota)] + else: + sys.exit('<<< ERROR >>> Not an valid email address: %s' % _email) +else: + usage() + +total = len(users) +logger.info('{} users in total.'.format(total)) + +count = 1 +if backend == 'ldap': + from libs.ldaplib.core import LDAPWrap + from libs.ldaplib.ldaputils import rdn_value_to_user_dn, mod_replace + _wrap = LDAPWrap() + conn = _wrap.conn + + for (_email, _quota) in users: + logger.info('(%d/%d) Updating %s -> %s' % (count, total, _email, _quota)) + dn = rdn_value_to_user_dn(_email) + mod_attrs = mod_replace('mailQuota', _quota) + try: + conn.modify_s(dn, mod_attrs) + except Exception as e: + print("<<< ERROR >>> {}".format(e)) +elif backend in ['mysql', 'pgsql']: + conn = get_db_conn('vmail') + for (_email, _quota) in users: + logger.info('(%d/%d) Updating %s -> %s' % (count, total, _email, _quota)) + conn.update('mailbox', + quota=int(_quota), + where="username='%s'" % _email) diff --git a/tools/update_password_in_csv.py b/tools/update_password_in_csv.py new file mode 100644 index 0000000..efdf7d1 --- /dev/null +++ b/tools/update_password_in_csv.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +# Author: Zhang Huangbin +# Purpose: Update user passwords from records in a CSV file. + +import os +import sys +import web + + +def usage(): + print("""Usage: + + - Store the email address and new password in a plain text file, e.g. + 'passwords.csv'. format is: + + + + Samples: + + user1@domain.com pF4mTq4jaRzDLlWl + user2@domain.com SPhkTUlZs1TBxvmJ + user3@domain.com 8deNR8IBLycRujDN + + - Run this script with this file: + + python3 update_password_in_csv.py passwords.csv + """) + + +os.environ['LC_ALL'] = 'C' + +rootdir = os.path.abspath(os.path.dirname(__file__)) + '/../' +sys.path.insert(0, rootdir) + +import settings +from tools.ira_tool_lib import debug, logger, get_db_conn +from libs.iredutils import is_email +from libs.iredpwd import generate_password_hash + +backend = settings.backend +logger.info('Backend: %s' % backend) + +web.config.debug = debug + +logger.info('Parsing command line arguments.') + +# File which stores email and quota. +text_file = '' + +# The separator +column_separator = ' ' + +# List of (email, quota) tuples. +users = [] + +# Check arguments +if len(sys.argv) == 2: + text_file = sys.argv[1] + if not os.path.isfile(text_file): + sys.exit('<<< ERROR>>> Not a regular file: %s' % text_file) + + # Get all (email, password) tuples. + f = open(text_file) + line_num = 0 + for _line in f.readlines(): + line_num += 1 + (_email, _pw) = _line.split(column_separator, 1) + if is_email(_email): + users += [(_email, _pw)] + else: + print("[SKIP] line {}: no valid email address: {}".format(line_num, _line)) + f.close() +else: + usage() + +total = len(users) +logger.info('%d users in total.' % total) + +count = 1 +if backend == 'ldap': + from libs.ldaplib.core import LDAPWrap + from libs.ldaplib.ldaputils import rdn_value_to_user_dn, mod_replace + _wrap = LDAPWrap() + conn = _wrap.conn + + for (_email, _pw) in users: + logger.info('(%d/%d) Updating %s' % (count, total, _email)) + + dn = rdn_value_to_user_dn(_email) + pw_hash = generate_password_hash(_pw) + mod_attrs = mod_replace('userPassword', pw_hash) + try: + conn.modify_s(dn, mod_attrs) + except Exception as e: + print("<<< ERROR >>> {}".format(repr(e))) +elif backend in ['mysql', 'pgsql']: + conn = get_db_conn('vmail') + for (_email, _pw) in users: + logger.info('(%d/%d) Updating %s' % (count, total, _email)) + pw_hash = generate_password_hash(_pw) + conn.update('mailbox', + password=pw_hash, + where="username='%s'" % _email) diff --git a/tools/upgrade_iredadmin.sh b/tools/upgrade_iredadmin.sh new file mode 100644 index 0000000..0f30bb0 --- /dev/null +++ b/tools/upgrade_iredadmin.sh @@ -0,0 +1,963 @@ +#!/usr/bin/env bash + +# Purpose: Upgrade iRedAdmin from old release. +# Works with both iRedAdmin open source edition or iRedAdmin-Pro. + +# USAGE: +# +# # cd /path/to/iRedAdmin-xxx/tools/ +# # bash upgrade_iredadmin.sh +# +# Notes: +# +# * it uses sql username 'root' by default to connect to sql database. If you +# are using a remote SQL database which you don't have root privilege, +# please specify the sql username on command line with 'SQL_IREDADMIN_USER' +# parameter like this: +# +# SQL_IREDADMIN_USER='iredadmin' bash upgrade_iredadmin.sh +# +# * it reads sql password for given sql user from /root/.my.cnf by default. +# if you use a different file, please specify the file on command line with +# 'MY_CNF' parameter like this: +# +# MY_CNF='/root/.my.cnf-iredadmin' SQL_IREDADMIN_USER='iredadmin' bash upgrade_iredadmin.sh + +export LC_ALL='C' +export IRA_HTTPD_USER='iredadmin' +export IRA_HTTPD_GROUP='iredadmin' + +export SYS_ROOT_USER='root' + +# If you don't have root privilege, use another sql user instead. +export SQL_IREDADMIN_USER="${SQL_IREDADMIN_USER:=root}" +export MY_CNF="${MY_CNF:=/root/.my.cnf}" +export CMD_MYSQL="mysql --defaults-file=${MY_CNF} -u ${SQL_IREDADMIN_USER}" + +# Check OS to detect some necessary info. +export KERNEL_NAME="$(uname -s | tr '[a-z]' '[A-Z]')" + +export NGINX_PID_FILE='/var/run/nginx.pid' +export NGINX_SNIPPET_CONF='/etc/nginx/templates/iredadmin.tmpl' +export NGINX_SNIPPET_CONF2='/etc/nginx/templates/iredadmin-subdomain.tmpl' +# iRedMail-0.9.7 +export NGINX_SNIPPET_CONF3='/etc/nginx/conf.d/default.conf' + +export USE_SYSTEMD='NO' +if which systemctl &>/dev/null; then + export USE_SYSTEMD='YES' + export SYSTEMD_SERVICE_DIR='/lib/systemd/system' + export SYSTEMD_SERVICE_DIR2='/etc/systemd/system' + export SYSTEMD_SERVICE_USER_DIR='/etc/systemd/system/multi-user.target.wants/' +fi + +# Python. +export CMD_PYTHON3='/usr/bin/python3' +export CMD_PIP3='/usr/bin/pip3' + +# uwsgi +export CMD_UWSGI='/usr/bin/uwsgi' + +# If uwsgi is installed with pip, plugins are compiled into core binary +# directly, not plugins are installed separately. +# Mainly used on RHEL/CentOS/Rocky/Alma. +export UWSGI_HAS_PLUGINS="YES" + +if [ X"${KERNEL_NAME}" == X"LINUX" ]; then + # Note: RHEL has minor version number in VERSION_ID. + export DISTRO_VERSION=$(awk -F'"' '/^VERSION_ID=/ {print $2}' /etc/os-release | awk -F'.' '{print $1}') + + if [ -f /etc/redhat-release ]; then + # RHEL/CentOS + export DISTRO='RHEL' + + # Installed with pip. + export CMD_UWSGI='/usr/sbin/uwsgi' + + if [[ -x "/usr/local/bin/uwsgi" ]]; then + export CMD_UWSGI='/usr/local/bin/uwsgi' + export UWSGI_HAS_PLUGINS="NO" + fi + + export HTTPD_RC_SCRIPT_NAME='httpd' + export CRON_SPOOL_DIR='/var/spool/cron' + + if [[ -L /opt/www/iredadmin ]]; then + export HTTPD_SERVERROOT='/opt/www' + else + export HTTPD_SERVERROOT='/var/www' + fi + elif [ -f /etc/lsb-release ]; then + # Ubuntu + export DISTRO='UBUNTU' + + export HTTPD_RC_SCRIPT_NAME='apache2' + export CRON_SPOOL_DIR='/var/spool/cron/crontabs' + + if [ -L /opt/www/iredadmin ]; then + export HTTPD_SERVERROOT='/opt/www' + else + export HTTPD_SERVERROOT='/usr/share/apache2' + fi + elif [ -f /etc/debian_version ]; then + # Debian + export DISTRO='DEBIAN' + + export HTTPD_RC_SCRIPT_NAME='apache2' + export CRON_SPOOL_DIR='/var/spool/cron/crontabs' + + if [ -L /opt/www/iredadmin ]; then + export HTTPD_SERVERROOT='/opt/www' + else + export HTTPD_SERVERROOT='/usr/share/apache2' + fi + elif [ -f /etc/SuSE-release ]; then + # openSUSE + export DISTRO='SUSE' + export HTTPD_SERVERROOT='/srv/www' + export HTTPD_RC_SCRIPT_NAME='apache2' + export CRON_SPOOL_DIR='/var/spool/cron' + else + echo "<<< ERROR >>> Cannot detect Linux distribution name. Exit." + echo "Please contact support@iredmail.org to solve it." + exit 255 + fi +elif [ X"${KERNEL_NAME}" == X'FREEBSD' ]; then + export DISTRO='FREEBSD' + export SYSRC='/usr/sbin/sysrc' + + export CMD_PYTHON3='/usr/local/bin/python3' + export CMD_UWSGI='/usr/local/bin/uwsgi' + + [ -x /usr/local/bin/pip-3.8 ] && export CMD_PIP3='/usr/local/bin/pip-3.8' + [ -x /usr/local/bin/pip3 ] && export CMD_PIP3='/usr/local/bin/pip3' + [ -x /usr/local/bin/pip ] && export CMD_PIP3='/usr/local/bin/pip' + + export CRON_SPOOL_DIR='/var/cron/tabs' + export NGINX_SNIPPET_CONF='/usr/local/etc/nginx/templates/iredadmin.tmpl' + export NGINX_SNIPPET_CONF2='/usr/local/etc/nginx/templates/iredadmin-subdomain.tmpl' + export NGINX_SNIPPET_CONF3='/usr/local/etc/nginx/conf.d/default.conf' + + if [ -L /opt/www/iredadmin ]; then + export HTTPD_SERVERROOT='/opt/www' + else + export HTTPD_SERVERROOT='/usr/local/www' + fi + + if [ -f /usr/local/etc/rc.d/apache24 ]; then + export HTTPD_RC_SCRIPT_NAME='apache24' + else + export HTTPD_RC_SCRIPT_NAME='apache22' + fi +elif [ X"${KERNEL_NAME}" == X'OPENBSD' ]; then + export CMD_PYTHON3='/usr/local/bin/python3' + export CMD_PIP3='/usr/local/bin/pip3' + export CMD_UWSGI='/usr/local/bin/uwsgi' + export DISTRO='OPENBSD' + export CRON_SPOOL_DIR='/var/cron/tabs' + + if [[ -h /opt/www/iredadmin ]]; then + export HTTPD_SERVERROOT='/opt/www' + else + export HTTPD_SERVERROOT='/var/www' + fi +else + echo "Cannot detect Linux/BSD distribution. Exit." + echo "Please contact author iRedMail team to solve it." + exit 255 +fi + +export CRON_FILE_ROOT="${CRON_SPOOL_DIR}/${SYS_ROOT_USER}" + +# Optional argument to set the directory which stores iRedAdmin. +if [ $# -gt 0 ]; then + if [ -d ${1} ]; then + export HTTPD_SERVERROOT="${1}" + fi + + if echo ${HTTPD_SERVERROOT} | grep '/iredadmin/*$' > /dev/null; then + export HTTPD_SERVERROOT="$(dirname ${HTTPD_SERVERROOT})" + fi +fi + +# iRedAdmin directory and config file. +export IRA_ROOT_DIR="${HTTPD_SERVERROOT}/iredadmin" +export IRA_CONF_PY="${IRA_ROOT_DIR}/settings.py" +export IRA_CUSTOM_CONF_PY="${IRA_ROOT_DIR}/custom_settings.py" + +enable_service() { + srv="$1" + + echo "* Enable service: ${srv}" + if [ X"${DISTRO}" == X'RHEL' ]; then + if [ X"${USE_SYSTEMD}" == X'YES' ]; then + systemctl enable $srv + else + chkconfig --level 345 $srv on + fi + elif [ X"${DISTRO}" == X'DEBIAN' -o X"${DISTRO}" == X'UBUNTU' ]; then + if [ X"${USE_SYSTEMD}" == X'YES' ]; then + systemctl enable $srv + else + update-rc.d $srv defaults + fi + elif [ X"${DISTRO}" == X'FREEBSD' ]; then + ${SYSRC} -f /etc/rc.conf.local ${srv}_enable=YES + elif [ X"${DISTRO}" == X'OPENBSD' ]; then + rcctl enable $srv + fi +} + +restart_service() { + srv="$1" + + if [ X"${KERNEL_NAME}" == X'LINUX' ]; then + if [ X"${USE_SYSTEMD}" == X'YES' ]; then + systemctl restart ${srv} + else + service ${srv} restart + fi + elif [ X"${KERNEL_NAME}" == X'FREEBSD' ]; then + service ${srv} restart + elif [ X"${KERNEL_NAME}" == X'OPENBSD' ]; then + rcctl restart ${srv} + fi + + if [ X"$?" != X'0' ]; then + echo "Failed, please restart service manually and check its log file." + fi +} + +restart_web_service() +{ + export web_service="${HTTPD_RC_SCRIPT_NAME}" + if [ -f ${NGINX_PID_FILE} ]; then + if [ -n "$(cat ${NGINX_PID_FILE})" ]; then + export web_service="iredadmin" + fi + fi + + echo "* Restarting ${web_service} service." + if [ X"${KERNEL_NAME}" == X'LINUX' ]; then + # The uwsgi script on CentOS 6 has problem with 'restart' action, + # 'stop' with few seconds sleep fixes it. + if [ X"${DISTRO}" == X'RHEL' -a X"${web_service}" == X'uwsgi' ]; then + service ${web_service} stop + sleep 5 + service ${web_service} start + else + service ${web_service} restart + fi + elif [ X"${KERNEL_NAME}" == X'FREEBSD' ]; then + service ${web_service} restart + elif [ X"${KERNEL_NAME}" == X'OPENBSD' ]; then + rcctl restart ${web_service} + fi + + if [ X"$?" != X'0' ]; then + echo "Failed, please restart Apache web server or 'iredadmin' (if you're running Nginx as web server) manually." + fi +} + +check_mlmmjadmin_installation() +{ + if [ ! -e /opt/mlmmjadmin ]; then + echo "<<< ERROR >>> No mlmmjadmin installation found (/opt/mlmmjadmin)." + echo "<<< ERROR >>> Please follow iRedMail upgrade tutorials to the latest" + echo "<<< ERROR >>> stable release first, then come back to upgrade iRedAdmin-Pro." + echo "<<< ERROR >>> mlmmj and mlmmjadmin was first introduced in iRedMail-0.9.8." + echo "<<< ERROR >>> https://docs.iredmail.org/iredmail.releases.html" + exit 255 + fi +} + +remove_pkg() { + echo "Remove package(s): $@" + if [ X"${DISTRO}" == X'RHEL' ]; then + yum remove -y $@ + fi +} + +install_pkg() +{ + echo "Install package(s): $@" + + if [ X"${DISTRO}" == X'RHEL' ]; then + yum -y install $@ + elif [ X"${DISTRO}" == X'DEBIAN' -o X"${DISTRO}" == X'UBUNTU' ]; then + apt-get install -y $@ + elif [ X"${DISTRO}" == X'FREEBSD' ]; then + cd /usr/ports/$@ && make install clean + elif [ X"${DISTRO}" == X'OPENBSD' ]; then + pkg_add -r $@ + else + echo "<< ERROR >> Please install package(s) manually: $@" + fi +} + +has_python_module() +{ + mod="$1" + ${CMD_PYTHON3} -c "import $mod" &>/dev/null + if [ X"$?" == X'0' ]; then + echo 'YES' + else + echo 'NO' + fi +} + +add_missing_parameter() +{ + # Usage: add_missing_parameter VARIABLE DEFAULT_VALUE [COMMENT] + var="${1}" + value="${2}" + shift 2 + comment="$@" + + if ! grep "^${var}" ${IRA_CONF_PY} &>/dev/null; then + if [ ! -z "${comment}" ]; then + echo "# ${comment}" >> ${IRA_CONF_PY} + fi + + if [ X"${value}" == X'True' -o X"${value}" == X'False' ]; then + echo "${var} = ${value}" >> ${IRA_CONF_PY} + else + # Value must be quoted as string. + echo "${var} = '${value}'" >> ${IRA_CONF_PY} + fi + fi +} + +# Remove all single quote and double quotes in string. +strip_quotes() +{ + # Read input from stdin + str="$(cat <&0)" + value="$(echo ${str} | tr -d '"' | tr -d "'")" + echo "${value}" +} + +get_iredadmin_setting() +{ + var="${1}" + value="$(grep "^${var}" ${IRA_CONF_PY} | awk '{print $NF}' | strip_quotes)" + echo "${value}" +} + +check_dot_my_cnf() +{ + if egrep '^backend.*(mysql|ldap)' ${IRA_CONF_PY} &>/dev/null; then + if [ ! -f ${MY_CNF} ]; then + echo "<<< ERROR >>> File ${MY_CNF} not found." + echo "<<< ERROR >>> Please add mysql root user and password in it like below, then run this script again." + cat </dev/null + if [ X"$?" != X'0' ]; then + echo "<<< ERROR >>> MySQL user name '${SQL_IREDADMIN_USER}' or password is incorrect in ${MY_CNF}, please double check." + exit 255 + fi + fi +} + +check_mlmmjadmin_installation +check_dot_my_cnf + +echo "* Detected Linux/BSD distribution: ${DISTRO}" +echo "* HTTP server root: ${HTTPD_SERVERROOT}" + +if [ -L ${IRA_ROOT_DIR} ]; then + export IRA_ROOT_REAL_DIR="$(readlink ${IRA_ROOT_DIR})" + echo "* Found iRedAdmin directory: ${IRA_ROOT_DIR}, symbol link of ${IRA_ROOT_REAL_DIR}" +else + echo "<<< ERROR >>> Directory (${IRA_ROOT_DIR}) is not a symbol link created by iRedMail. Exit." + exit 255 +fi + +# Copy config file +if [ -f ${IRA_CONF_PY} ]; then + echo "* Found iRedAdmin config file: ${IRA_CONF_PY}" +else + echo "<<< ERROR >>> Cannot find a valid config file (settings.py)." + exit 255 +fi + +# Check whether current directory is iRedAdmin +PWD="$(pwd)" +if ! echo ${PWD} | grep 'iRedAdmin-.*/tools' >/dev/null; then + echo "<<< ERROR >>> Cannot find new version of iRedAdmin in current directory. Exit." + exit 255 +fi + +# Copy current directory to Apache server root +dir_new_version="$(dirname ${PWD})" +name_new_version="$(basename ${dir_new_version})" +NEW_IRA_ROOT_DIR="${HTTPD_SERVERROOT}/${name_new_version}" +if [ -d ${NEW_IRA_ROOT_DIR} ]; then + COPY_FILES="${dir_new_version}/*" + COPY_DEST_DIR="${NEW_IRA_ROOT_DIR}" + #echo "<<< ERROR >>> Directory exist: ${NEW_IRA_ROOT_DIR}. Exit." + #exit 255 +else + COPY_FILES="${dir_new_version}" + COPY_DEST_DIR="${HTTPD_SERVERROOT}" +fi + +echo "* Copying new version to ${NEW_IRA_ROOT_DIR}" +cp -rf ${COPY_FILES} ${COPY_DEST_DIR} + +# Copy old config files +echo "* Copy ${IRA_CONF_PY}." +cp -p ${IRA_CONF_PY} ${NEW_IRA_ROOT_DIR}/ + +if [ -f ${IRA_CUSTOM_CONF_PY} ]; then + echo "* Copy ${IRA_CUSTOM_CONF_PY}." + cp -p ${IRA_CUSTOM_CONF_PY} ${NEW_IRA_ROOT_DIR} +fi + +# Copy hooks.py. It's ok if missing. +if [ -f ${IRA_ROOT_DIR}/hooks.py ]; then + echo "* Copy ${IRA_ROOT_DIR}/hooks.py." + cp -p ${IRA_ROOT_DIR}/hooks.py ${NEW_IRA_ROOT_DIR}/ &>/dev/null +fi + +# Copy custom files under 'tools/'. It's ok if missing. +cp -p ${IRA_ROOT_DIR}/tools/*.custom ${NEW_IRA_ROOT_DIR}/tools/ &>/dev/null +cp -p ${IRA_ROOT_DIR}/tools/*.last-time ${NEW_IRA_ROOT_DIR}/tools/ &>/dev/null + +# Template file renamed +if [ -f "${IRA_ROOT_DIR}/tools/notify_quarantined_recipients.custom.html" ]; then + echo "* Copy ${IRA_ROOT_DIR}/tools/notify_quarantined_recipients.custom.html" + cp -f ${IRA_ROOT_DIR}/tools/notify_quarantined_recipients.custom.html \ + ${NEW_IRA_ROOT_DIR}/tools/notify_quarantined_recipients.html.custom +fi + +# Copy favicon.ico and brand logo image. +for var in 'BRAND_FAVICON' 'BRAND_LOGO'; do + if grep "^${var}\>" ${IRA_CONF_PY} &>/dev/null; then + _file=$(grep "^${var}\>" ${IRA_CONF_PY} | awk '{print $NF}' | tr -d '"' | tr -d "'") + echo "* Copy file ${IRA_ROOT_DIR}/static/${_file}" + cp -f ${IRA_ROOT_DIR}/static/${_file} ${NEW_IRA_ROOT_DIR}/static/ + fi +done + +# iredadmin is now ran as a standalone uwsgi instance, we don't need uwsgi +# daemon service anymore. +_uwsgi_confs=' + /etc/uwsgi.d/iredadmin.ini + /etc/uwsgi-available/iredadmin.ini + /etc/uwsgi/apps-enabled/iredadmin.ini &>/dev/null + /etc/uwsgi/apps-available/iredadmin.ini &>/dev/null + /usr/local/etc/uwsgi/iredadmin.ini + /etc/uwsgi-enabled/iredadmin.ini &>/dev/null + /etc/uwsgi-available/iredadmin.ini &>/dev/null +' + +for f in ${_uwsgi_confs}; do + rm -f ${f} &>/dev/null +done + +# Remove 'uwsgi_XXX' from /etc/rc.conf on FreeBSD. +if [[ X"${DISTRO}" == X'FREEBSD' ]]; then + ${SYSRC} -x uwsgi_enable &>/dev/null + ${SYSRC} -x uwsgi_profiles &>/dev/null + ${SYSRC} -x uwsgi_iredadmin_flags &>/dev/null +fi + +# Update Nginx template file +export _restart_nginx='NO' +for f in ${NGINX_SNIPPET_CONF} ${NGINX_SNIPPET_CONF2} ${NGINX_SNIPPET_CONF3}; do + if [[ -f ${f} ]]; then + if grep 'unix:.*iredadmin.socket' ${f} &>/dev/null; then + export _restart_nginx='YES' + perl -pi -e 's#uwsgi_pass unix:.*iredadmin.socket;#uwsgi_pass 127.0.0.1:7791;#g' ${f} + fi + fi +done + +if [[ X"${_restart_nginx}" == X'YES' ]]; then + restart_service nginx +fi + +# Update uwsgi ini config file +if [ -d ${NEW_IRA_ROOT_DIR}/rc_scripts/uwsgi ]; then + perl -pi -e 's#^chdir = (.*)#chdir = $ENV{HTTPD_SERVERROOT}/iredadmin#g' ${NEW_IRA_ROOT_DIR}/rc_scripts/uwsgi/*.ini +fi + +# Copy rc script or systemd service file +if [ X"${USE_SYSTEMD}" == X'YES' ]; then + echo "* Remove existing systemd service files." + rm -f ${SYSTEMD_SERVICE_DIR}/iredadmin.service &>/dev/null + rm -f ${SYSTEMD_SERVICE_DIR2}/iredadmin.service &>/dev/null + rm -f ${SYSTEMD_SERVICE_USER_DIR}/iredadmin.service &>/dev/null + + echo "* Copy systemd service file: ${SYSTEMD_SERVICE_DIR}/iredadmin.service." + if [ X"${DISTRO}" == X'RHEL' ]; then + cp -f ${NEW_IRA_ROOT_DIR}/rc_scripts/systemd/rhel${DISTRO_VERSION}.service ${SYSTEMD_SERVICE_DIR}/iredadmin.service + perl -pi -e 's#/opt/www#$ENV{HTTPD_SERVERROOT}#g' ${SYSTEMD_SERVICE_DIR}/iredadmin.service + perl -pi -e 's#/usr/local/bin/uwsgi#$ENV{CMD_UWSGI}#g' ${SYSTEMD_SERVICE_DIR}/iredadmin.service + + if [[ X"${UWSGI_HAS_PLUGINS}" == X"NO" ]]; then + _ini_file="${NEW_IRA_ROOT_DIR}/rc_scripts/uwsgi/rhel${DISTRO_VERSION}.ini" + if [[ -f ${_ini_file} ]]; then + perl -pi -e 's#^(plugins.*)##g' ${_ini_file} + fi + fi + elif [ X"${DISTRO}" == X'DEBIAN' -o X"${DISTRO}" == X'UBUNTU' ]; then + cp -f ${NEW_IRA_ROOT_DIR}/rc_scripts/systemd/debian.service ${SYSTEMD_SERVICE_DIR}/iredadmin.service + perl -pi -e 's#/opt/www#$ENV{HTTPD_SERVERROOT}#g' ${SYSTEMD_SERVICE_DIR}/iredadmin.service + fi + + chmod -R 0644 ${SYSTEMD_SERVICE_DIR}/iredadmin.service + systemctl daemon-reload &>/dev/null +else + if [ X"${DISTRO}" == X"FREEBSD" ]; then + cp ${NEW_IRA_ROOT_DIR}/rc_scripts/iredadmin.freebsd /usr/local/etc/rc.d/iredadmin + perl -pi -e 's#/opt/www#$ENV{HTTPD_SERVERROOT}#g' /usr/local/etc/rc.d/iredadmin + + # Remove 'uwsgi_iredadmin_flags=' in /etc/rc.conf.local + if [ -f /etc/rc.conf.local ]; then + perl -pi -e 's#^uwsgi_iredadminflags=.*##g' /etc/rc.conf.local + fi + elif [ X"${DISTRO}" == X'OPENBSD' ]; then + cp ${NEW_IRA_ROOT_DIR}/rc_scripts/iredadmin.openbsd ${DIR_RC_SCRIPTS}/iredadmin + perl -pi -e 's#/opt/www#$ENV{HTTPD_SERVERROOT}#g' /etc/rc.d/iredadmin + + cp -f ${NEW_IRA_ROOT_DIR}/rc_scripts/iredadmin.openbsd /etc/rc.d/iredadmin + chmod 0755 /etc/rc.d/iredadmin + + # Remove 'uwsgi_flags=' in /etc/rc.conf.local + if [ -f /etc/rc.conf.local ]; then + perl -pi -e 's#^uwsgi_flags=.*iredadmin.*##g' /etc/rc.conf.local + fi + fi +fi + +# Set owner and permission. +chown -R ${IRA_HTTPD_USER}:${IRA_HTTPD_GROUP} ${NEW_IRA_ROOT_DIR} +chmod -R 0555 ${NEW_IRA_ROOT_DIR} +chmod 0400 ${NEW_IRA_ROOT_DIR}/settings.py + +echo "* Removing old symbol link ${IRA_ROOT_DIR}" +rm -f ${IRA_ROOT_DIR} + +echo "* Creating symbol link ${IRA_ROOT_DIR} to ${NEW_IRA_ROOT_DIR}" +cd ${HTTPD_SERVERROOT} +ln -s ${name_new_version} iredadmin + +# Add missing setting parameters. +if grep 'amavisd_enable_logging.*True.*' ${IRA_CONF_PY} &>/dev/null; then + add_missing_parameter 'amavisd_enable_policy_lookup' True 'Enable per-recipient spam policy, white/blacklist.' +else + add_missing_parameter 'amavisd_enable_policy_lookup' False 'Enable per-recipient spam policy, white/blacklist.' +fi + +if ! grep '^iredapd_' ${IRA_CONF_PY} &>/dev/null; then + add_missing_parameter 'iredapd_enabled' True 'Enable iRedAPD integration.' + + # Get iredapd db password from /opt/iredapd/settings.py. + if [ -f /opt/iredapd/settings.py ]; then + grep '^iredapd_db_' /opt/iredapd/settings.py >> ${IRA_CONF_PY} + perl -pi -e 's#iredapd_db_server#iredapd_db_host#g' ${IRA_CONF_PY} + else + # Check backend. + if egrep '^backend.*pgsql' ${IRA_CONF_PY} &>/dev/null; then + export IREDAPD_DB_PORT='5432' + else + export IREDAPD_DB_PORT='3306' + fi + + add_missing_parameter 'iredapd_db_host' '127.0.0.1' + add_missing_parameter 'iredapd_db_port' ${IREDAPD_DB_PORT} + add_missing_parameter 'iredapd_db_name' 'iredapd' + add_missing_parameter 'iredapd_db_user' 'iredapd' + add_missing_parameter 'iredapd_db_password' 'password' + fi +fi +perl -pi -e 's#iredapd_db_server#iredapd_db_host#g' ${IRA_CONF_PY} + +if ! grep '^fail2ban_' ${IRA_CONF_PY} &>/dev/null; then + # Try to get password of SQL user `fail2ban`. + if egrep '^backend.*(mysql|ldap)' ${IRA_CONF_PY} &>/dev/null; then + _my_cnf='/root/.my.cnf-fail2ban' + if [ -f ${_my_cnf} ]; then + _host="$(grep '^host=' ${_my_cnf} | awk -F'host=' '{print $2}' | strip_quotes)" + _port="$(grep '^port=' ${_my_cnf} | awk -F'port=' '{print $2}' | strip_quotes)" + _user="$(grep '^user=' ${_my_cnf} | awk -F'user=' '{print $2}' | strip_quotes)" + _password="$(grep '^password=' ${_my_cnf} | awk -F'password=' '{print $2}' | strip_quotes)" + + [ X"${_host}" == X'' ] && _host='127.0.0.1' + [ X"${_port}" == X'' ] && _port='3306' + fi + elif egrep '^backend.*pgsql' ${IRA_CONF_PY} &>/dev/null; then + # Absolute path to ~/.pgpass + # - RHEL: /var/lib/pgsql/.pgpass + # - Debian/Ubuntu: /var/lib/postgresql/.pgpass + # - FreeBSD: /var/db/postgres/.pgpass + # - OpenBSD: /var/postgresql/.pgpass + for dir in \ + /var/lib/pgsql \ + /var/lib/postgresql \ + /var/db/postgres \ + /var/postgresql; do + _pgpass="${dir}/.pgpass" + if [ -f ${_pgpass} ]; then + if grep ':fail2ban:' ${_pgpass} &>/dev/null; then + _host="127.0.0.1" + _port="5432" + _user="fail2ban" + _password="$(grep ':fail2ban:' ${_pgpass} | awk -F':' '{print $NF}')" + break + fi + fi + done + fi + + if [ X"${_host}" != X'' ] && \ + [ X"${_port}" != X'' ] && \ + [ X"${_user}" != X'' ] && \ + [ X"${_password}" != X'' ]; then + echo "* Enable Fail2ban SQL integration." + add_missing_parameter 'fail2ban_enabled' 'True' + add_missing_parameter 'fail2ban_db_host' "${_host}" + add_missing_parameter 'fail2ban_db_port' "${_port}" + add_missing_parameter 'fail2ban_db_name' "fail2ban" + add_missing_parameter 'fail2ban_db_user' "${_user}" + add_missing_parameter 'fail2ban_db_password' "${_password}" + fi +fi + +# FreeBSD uses /var/run/log for syslog. +if [ X"${DISTRO}" == X'FREEBSD' ]; then + add_missing_parameter 'SYSLOG_SERVER' '/var/run/log' +fi + +# +# Enable mlmmj integration +# +if [ -e /opt/mlmmjadmin ]; then + echo "* Enable mlmmj integration." + # Force to use backend `bk_none`. + perl -pi -e 's#^(backend_api).*#${1} = "bk_none"#g' /opt/mlmmjadmin/settings.py + + if egrep '^backend.*(ldap)' ${IRA_CONF_PY} &>/dev/null; then + perl -pi -e 's#^(backend_cli).*#${1} = "bk_iredmail_ldap"#g' /opt/mlmmjadmin/settings.py + else + perl -pi -e 's#^(backend_cli).*#${1} = "bk_iredmail_sql"#g' /opt/mlmmjadmin/settings.py + fi + + # Add parameter `mlmmjadmin_api_auth_token` if missing + if ! grep '^mlmmjadmin_api_auth_token' ${IRA_CONF_PY} >/dev/null; then + # Get first api auth token + token=$(grep '^api_auth_tokens' /opt/mlmmjadmin/settings.py | awk -F"[=\']" '{print $3}' | tr -d '\n') + echo -e "\nmlmmjadmin_api_auth_token = '${token}'" >> ${IRA_CONF_PY} + fi + + echo "* Restarting service: mlmmjadmin." + restart_service mlmmjadmin +fi + +# Change old parameter names to the new ones: +# +# - ADDITION_USER_SERVICES -> ADDITIONAL_ENABLED_USER_SERVICES +# - LDAP_SERVER_NAME -> LDAP_SERVER_PRODUCT_NAME +perl -pi -e 's#ADDITION_USER_SERVICES#ADDITIONAL_ENABLED_USER_SERVICES#g' ${IRA_CONF_PY} +perl -pi -e 's#LDAP_SERVER_NAME#LDAP_SERVER_PRODUCT_NAME#g' ${IRA_CONF_PY} + +# Remove deprecated setting: ENABLE_SELF_SERVICE, it's now a per-domain setting. +perl -pi -e 's#^(ENABLE_SELF_SERVICE.*)##g' ${IRA_CONF_PY} + + +# Dependent packages. +export REQUIRED_PKGS="" +export PIP3_MODS="" +# Python modules. +export PKG_PY_PIP='python3-pip' +export PKG_PY_LDAP='python3-ldap' +export PKG_PY_MYSQL='python3-pymysql' +export PKG_PY_PGSQL='python3-psycopg2' +export PKG_PY_JSON='python3-simplejson' +export PKG_PY_DNS='python3-dnspython' +export PKG_PY_REQUESTS='python3-requests' +export PKG_PY_JINJA='python3-jinja2' +# Python modules installed with pip3: uwsgi. + +if [ X"${DISTRO}" == X'RHEL' ]; then + if [ X"${DISTRO_VERSION}" == X'7' ]; then + export PKG_PY_MYSQL='python36-PyMySQL' + export PKG_PY_JSON='python36-simplejson' + export PKG_PY_JINJA='python36-jinja2' + export REQUIRED_PKGS="${REQUIRED_PKGS} uwsgi uwsgi-plugin-python36 uwsgi-plugin-syslog" + + if rpm -q mod_wsgi &>/dev/null; then + remove_pkg mod_wsgi + export REQUIRED_PKGS="${REQUIRED_PKGS} python3-mod_wsgi" + fi + + else + if [ ! -x ${CMD_UWSGI} ]; then + # gcc is required to install uwsgi. + export REQUIRED_PKGS="${REQUIRED_PKGS} python3-devel python3-pip gcc" + export PIP3_MODS="${PIP3_MODS} uwsgi" + fi + fi + + export PKG_PY_DNS='python3-dns' +elif [ X"${DISTRO}" == X'DEBIAN' -o X"${DISTRO}" == X'UBUNTU' ]; then + export REQUIRED_PKGS="${REQUIRED_PKGS} uwsgi-core uwsgi-plugin-python3" + + if [ X"${DISTRO_VERSION}" == X'9' ]; then + export PKG_PY_LDAP='python3-pyldap' + else + export PKG_PY_LDAP='python3-ldap' + fi +elif [ X"${DISTRO}" == X'OPENBSD' ]; then + export PKG_PY_PIP='py3-pip' + export PKG_PY_JSON='py3-simplejson' + export PKG_PY_DNS='py3-dnspython' + export PKG_PY_REQUESTS='py3-requests' + export PKG_PY_JINJA='py3-jinja2' + if [ X"${DISTRO_VERSION}" == X'6.6' -o X"${DISTRO_VERSION}" == X'6.7' ]; then + export PKG_PY_MYSQL='py3-mysqlclient' + else + export PKG_PY_MYSQL='py3-pymysql' + fi + + if [ ! -x ${CMD_UWSGI} ]; then + export PIP3_MODS="${PIP3_MODS} uwsgi" + fi +elif [ X"${DISTRO}" == X'FREEBSD' ]; then + export PKG_PY_PIP='devel/py-pip' + export PKG_UWSGI="www/uwsgi" + export PKG_PY_JSON='devel/py-simplejson' + export PKG_PY_DNS='dns/py-dnspython' + export PKG_PY_REQUESTS='www/py-requests' + export PKG_PY_JINJA='devel/py-Jinja2' + + if [ ! -x ${CMD_UWSGI} ]; then + export REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_UWSGI}" + fi +fi + +echo "* Check and install required packages." +if egrep '^backend.*ldap' ${IRA_CONF_PY} &>/dev/null; then + [ X"$(has_python_module ldap)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_LDAP}" + [ X"$(has_python_module pymysql)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_MYSQL}" +elif egrep '^backend.*mysql' ${IRA_CONF_PY} &>/dev/null; then + [ X"$(has_python_module pymysql)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_MYSQL}" +elif egrep '^backend.*pgsql' ${IRA_CONF_PY} &>/dev/null; then + [ X"$(has_python_module psycopg2)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_PGSQL}" +fi +[ X"$(has_python_module pip)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_PIP}" +[ X"$(has_python_module simplejson)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_JSON}" +[ X"$(has_python_module dns)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_DNS}" +[ X"$(has_python_module requests)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_REQUESTS}" +if [ X"$(has_python_module web)" == X'NO' ]; then + PIP3_MODS="${PIP3_MODS} web.py>=0.61" +else # Verify module version. + _webpy_ver=$(${CMD_PYTHON3} -c "import web; print(web.__version__)") + if echo ${_webpy_ver} | grep '^0\.[45]' &>/dev/null; then + PIP3_MODS="${PIP3_MODS} web.py>=0.61" + fi +fi +[ X"$(has_python_module jinja2)" == X'NO' ] && REQUIRED_PKGS="${REQUIRED_PKGS} ${PKG_PY_JINJA}" + +if [ X"${REQUIRED_PKGS}" != X'' ]; then + install_pkg ${REQUIRED_PKGS} + if [ X"$?" != X'0' ]; then + echo "Package installation failed, please check console output and fix it manually." + exist 255 + fi +fi + +if [ X"${PIP3_MODS}" != X'' ]; then + ${CMD_PIP3} install -U ${PIP3_MODS} + if [ X"$?" != X'0' ]; then + echo "Package installation failed, please check console output and fix it manually." + exist 255 + fi +fi + +#------------------------------------------ +# Add new SQL tables, drop deprecated ones. +# +export ira_db_host="$(get_iredadmin_setting 'iredadmin_db_host')" +export ira_db_port="$(get_iredadmin_setting 'iredadmin_db_port')" +export ira_db_name="$(get_iredadmin_setting 'iredadmin_db_name')" +export ira_db_user="$(get_iredadmin_setting 'iredadmin_db_user')" +export ira_db_password="$(get_iredadmin_setting 'iredadmin_db_password')" + +# +# Update sql tables +# +psql_conn="psql -h ${ira_db_host} \ + -p ${ira_db_port} \ + -U ${ira_db_user} \ + -d ${ira_db_name}" + +if egrep '^backend.*(mysql|ldap)' ${IRA_CONF_PY} &>/dev/null; then + echo "* Check SQL tables, and add missed ones - if there's any" + ${CMD_MYSQL} ${ira_db_name} -e "SOURCE ${IRA_ROOT_DIR}/SQL/iredadmin.mysql" + ${CMD_MYSQL} ${ira_db_name} -e "ALTER TABLE log MODIFY COLUMN msg TEXT;" + + # Add column `tracking.id`. + ${CMD_MYSQL} ${ira_db_name} -e "DESC tracking \G" | grep 'Field: id' &>/dev/null + if [ X"$?" != X'0' ]; then + ${CMD_MYSQL} ${ira_db_name} -e "ALTER TABLE tracking ADD COLUMN id BIGINT(20) UNSIGNED AUTO_INCREMENT PRIMARY KEY;" + fi + + # Set column `id` to `PRIMARY KEY` + _tables='deleted_mailboxes updatelog log tracking' + for _table in ${_tables}; do + ${CMD_MYSQL} ${ira_db_name} -e "DESC ${_table}" | grep '^id.*PRI.*auto_increment' &>/dev/null + + if [ X"$?" != X'0' ]; then + ${CMD_MYSQL} ${ira_db_name} -e "ALTER TABLE ${_table} ADD PRIMARY KEY (id)" + fi + done + +elif egrep '^backend.*pgsql' ${IRA_CONF_PY} &>/dev/null; then + export PGPASSWORD="${ira_db_password}" + + # Allow log.msg to store long text. + ${psql_conn} <' &>/dev/null + if [ X"$?" != X'0' ]; then + echo "* [SQL] Add new table: iredadmin.tracking." + + ${psql_conn} <' &>/dev/null + if [ X"$?" != X'0' ]; then + echo "* [SQL] Add new table: iredadmin.domain_ownership." + + ${psql_conn} <' &>/dev/null + if [ X"$?" != X'0' ]; then + echo "* [SQL] Add new table: iredadmin.newsletter_subunsub_confirms." + + _sql="$(cat ${IRA_ROOT_DIR}/SQL/snippets/newsletter_subunsub_confirms.pgsql)" + ${psql_conn} <' &>/dev/null + if [ X"$?" != X'0' ]; then + echo "* [SQL] Add new table: iredadmin.settings." + + _sql="$(cat ${IRA_ROOT_DIR}/SQL/snippets/settings.pgsql)" + ${psql_conn} </dev/null +if [[ ! -f ${CRON_FILE_ROOT} ]]; then + touch ${CRON_FILE_ROOT} &>/dev/null + chmod 0600 ${CRON_FILE_ROOT} &>/dev/null +fi + +# cron job: clean up database. +if ! grep 'iredadmin/tools/cleanup_db.py' ${CRON_FILE_ROOT} &>/dev/null; then + cat >> ${CRON_FILE_ROOT} </dev/null +EOF +fi + +# cron job: clean up database. +if ! grep 'iredadmin/tools/delete_mailboxes.py' ${CRON_FILE_ROOT} &>/dev/null; then + cat >> ${CRON_FILE_ROOT} </dev/null +fi + +# Delete all sessions to force admins to re-login. +cd ${NEW_IRA_ROOT_DIR}/tools/ +${CMD_PYTHON3} delete_sessions.py + +echo "* iRedAdmin has been successfully upgraded." +restart_web_service + +# Enable and restart service +enable_service iredadmin +restart_service iredadmin + +echo "* Upgrading completed." + +cat <>> If iRedAdmin doesn't work as expected, please post your issue in +<<< NOTE >>> our online support forum: http://www.iredmail.org/forum/ +EOF diff --git a/web/__init__.py b/web/__init__.py new file mode 100644 index 0000000..5df7f24 --- /dev/null +++ b/web/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +"""web.py: makes web apps (http://webpy.org)""" + +from . import ( # noqa: F401 + db, + debugerror, + form, + http, + httpserver, + net, + session, + template, + utils, + webapi, + wsgi, +) +from .application import * # noqa: F401,F403 +from .db import * # noqa: F401,F403 +from .debugerror import * # noqa: F401,F403 +from .http import * # noqa: F401,F403 +from .httpserver import * # noqa: F401,F403 +from .net import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 +from .webapi import * # noqa: F401,F403 +from .wsgi import * # noqa: F401,F403 + +__version__ = "0.62" +__author__ = [ + "Aaron Swartz ", + "Anand Chitipothu ", +] +__license__ = "public domain" +__contributors__ = "see http://webpy.org/changes" diff --git a/web/application.py b/web/application.py new file mode 100644 index 0000000..e953884 --- /dev/null +++ b/web/application.py @@ -0,0 +1,813 @@ +""" +Web application +(from web.py) +""" + +import itertools +import os +import sys +import traceback +import wsgiref.handlers +from importlib import reload +from inspect import isclass +from io import BytesIO +from urllib.parse import unquote, urlencode, urlparse + +from . import browser, httpserver, utils +from . import webapi as web +from . import wsgi +from .debugerror import debugerror +from .py3helpers import iteritems +from .utils import lstrips + +__all__ = [ + "application", + "auto_application", + "subdir_application", + "subdomain_application", + "loadhook", + "unloadhook", + "autodelegate", +] + + +class application: + """ + Application to delegate requests based on path. + + >>> urls = ("/hello", "hello") + >>> app = application(urls, globals()) + >>> class hello: + ... def GET(self): return "hello" + >>> + >>> app.request("/hello").data + 'hello' + """ + + # PY3DOCTEST: b'hello' + + def __init__(self, mapping=(), fvars={}, autoreload=None): + if autoreload is None: + autoreload = web.config.get("debug", False) + self.init_mapping(mapping) + self.fvars = fvars + self.processors = [] + + self.add_processor(loadhook(self._load)) + self.add_processor(unloadhook(self._unload)) + + if autoreload: + + def main_module_name(): + mod = sys.modules["__main__"] + file = getattr( + mod, "__file__", None + ) # make sure this works even from python interpreter + return file and os.path.splitext(os.path.basename(file))[0] + + def modname(fvars): + """find name of the module name from fvars.""" + file, name = fvars.get("__file__"), fvars.get("__name__") + if file is None or name is None: + return None + + if name == "__main__": + # Since the __main__ module can't be reloaded, the module has + # to be imported using its file name. + name = main_module_name() + return name + + mapping_name = utils.dictfind(fvars, mapping) + module_name = modname(fvars) + + def reload_mapping(): + """loadhook to reload mapping and fvars.""" + mod = __import__(module_name, None, None, [""]) + mapping = getattr(mod, mapping_name, None) + if mapping: + self.fvars = mod.__dict__ + self.init_mapping(mapping) + + self.add_processor(loadhook(Reloader())) + if mapping_name and module_name: + # when app is ran as part of a package, this puts the app into + # `sys.modules` correctly, otherwise the first change to the + # app module will not be picked up by Reloader + reload_mapping() + + self.add_processor(loadhook(reload_mapping)) + + # load __main__ module usings its filename, so that it can be reloaded. + if main_module_name() and "__main__" in sys.argv: + try: + __import__(main_module_name()) + except ImportError: + pass + + def _load(self): + web.ctx.app_stack.append(self) + + def _unload(self): + web.ctx.app_stack = web.ctx.app_stack[:-1] + + if web.ctx.app_stack: + # this is a sub-application, revert ctx to earlier state. + oldctx = web.ctx.get("_oldctx") + if oldctx: + web.ctx.home = oldctx.home + web.ctx.homepath = oldctx.homepath + web.ctx.path = oldctx.path + web.ctx.fullpath = oldctx.fullpath + + def _cleanup(self): + # Threads can be recycled by WSGI servers. + # Clearing up all thread-local state to avoid interefereing with subsequent requests. + utils.ThreadedDict.clear_all() + + def init_mapping(self, mapping): + self.mapping = list(utils.group(mapping, 2)) + + def add_mapping(self, pattern, classname): + self.mapping.append((pattern, classname)) + + def add_processor(self, processor): + """ + Adds a processor to the application. + + >>> urls = ("/(.*)", "echo") + >>> app = application(urls, globals()) + >>> class echo: + ... def GET(self, name): return name + ... + >>> + >>> def hello(handler): return "hello, " + handler() + ... + >>> app.add_processor(hello) + >>> app.request("/web.py").data + 'hello, web.py' + """ + # PY3DOCTEST: b'hello, web.py' + self.processors.append(processor) + + def request( + self, + localpart="/", + method="GET", + data=None, + host="0.0.0.0:8080", + headers=None, + https=False, + **kw, + ): + """Makes request to this application for the specified path and method. + Response will be a storage object with data, status and headers. + + >>> urls = ("/hello", "hello") + >>> app = application(urls, globals()) + >>> class hello: + ... def GET(self): + ... web.header('Content-Type', 'text/plain') + ... return "hello" + ... + >>> response = app.request("/hello") + >>> response.data + 'hello' + >>> response.status + '200 OK' + >>> response.headers['Content-Type'] + 'text/plain' + + To use https, use https=True. + + >>> urls = ("/redirect", "redirect") + >>> app = application(urls, globals()) + >>> class redirect: + ... def GET(self): raise web.seeother("/foo") + ... + >>> response = app.request("/redirect") + >>> response.headers['Location'] + 'http://0.0.0.0:8080/foo' + >>> response = app.request("/redirect", https=True) + >>> response.headers['Location'] + 'https://0.0.0.0:8080/foo' + + The headers argument specifies HTTP headers as a mapping object + such as a dict. + + >>> urls = ('/ua', 'uaprinter') + >>> class uaprinter: + ... def GET(self): + ... return 'your user-agent is ' + web.ctx.env['HTTP_USER_AGENT'] + ... + >>> app = application(urls, globals()) + >>> app.request('/ua', headers = { + ... 'User-Agent': 'a small jumping bean/1.0 (compatible)' + ... }).data + 'your user-agent is a small jumping bean/1.0 (compatible)' + + """ + # PY3DOCTEST: b'hello' + # PY3DOCTEST: b'your user-agent is a small jumping bean/1.0 (compatible)' + _p = urlparse(localpart) + path = _p.path + maybe_query = _p.query + + query = maybe_query or "" + + if "env" in kw: + env = kw["env"] + else: + env = {} + env = dict( + env, + HTTP_HOST=host, + REQUEST_METHOD=method, + PATH_INFO=path, + QUERY_STRING=query, + HTTPS=str(https), + ) + headers = headers or {} + + for k, v in headers.items(): + env["HTTP_" + k.upper().replace("-", "_")] = v + + if "HTTP_CONTENT_LENGTH" in env: + env["CONTENT_LENGTH"] = env.pop("HTTP_CONTENT_LENGTH") + + if "HTTP_CONTENT_TYPE" in env: + env["CONTENT_TYPE"] = env.pop("HTTP_CONTENT_TYPE") + + if method not in ["HEAD", "GET"]: + data = data or "" + + if isinstance(data, dict): + q = urlencode(data) + else: + q = data + + env["wsgi.input"] = BytesIO(q.encode("utf-8")) + # if not env.get('CONTENT_TYPE', '').lower().startswith('multipart/') and 'CONTENT_LENGTH' not in env: + if "CONTENT_LENGTH" not in env: + env["CONTENT_LENGTH"] = len(q) + response = web.storage() + + def start_response(status, headers): + response.status = status + response.headers = dict(headers) + response.header_items = headers + + data = self.wsgifunc()(env, start_response) + response.data = b"".join(data) + return response + + def browser(self): + return browser.AppBrowser(self) + + def handle(self): + fn, args = self._match(self.mapping, web.ctx.path) + return self._delegate(fn, self.fvars, args) + + def handle_with_processors(self): + def process(processors): + try: + if processors: + p, processors = processors[0], processors[1:] + return p(lambda: process(processors)) + else: + return self.handle() + except web.HTTPError: + raise + except (KeyboardInterrupt, SystemExit): + raise + except: + print(traceback.format_exc(), file=web.debug) + raise self.internalerror() + + # processors must be applied in the reverse order. (??) + return process(self.processors) + + def wsgifunc(self, *middleware): + """Returns a WSGI-compatible function for this application.""" + + def peep(iterator): + """Peeps into an iterator by doing an iteration + and returns an equivalent iterator. + """ + # wsgi requires the headers first + # so we need to do an iteration + # and save the result for later + try: + firstchunk = next(iterator) + except StopIteration: + firstchunk = "" + + return itertools.chain([firstchunk], iterator) + + def wsgi(env, start_resp): + # clear threadlocal to avoid interference of previous requests + self._cleanup() + + self.load(env) + try: + # allow uppercase methods only + if web.ctx.method.upper() != web.ctx.method: + raise web.nomethod() + + result = self.handle_with_processors() + if result and hasattr(result, "__next__"): + result = peep(result) + else: + result = [result] + except web.HTTPError as e: + result = [e.data] + + def build_result(result): + for r in result: + if isinstance(r, bytes): + yield r + else: + yield str(r).encode("utf-8") + + result = build_result(result) + + status, headers = web.ctx.status, web.ctx.headers + start_resp(status, headers) + + def cleanup(): + self._cleanup() + yield b"" # force this function to be a generator + + return itertools.chain(result, cleanup()) + + for m in middleware: + wsgi = m(wsgi) + + return wsgi + + def run(self, *middleware): + """ + Starts handling requests. If called in a CGI or FastCGI context, it will follow + that protocol. If called from the command line, it will start an HTTP + server on the port named in the first command line argument, or, if there + is no argument, on port 8080. + + `middleware` is a list of WSGI middleware which is applied to the resulting WSGI + function. + """ + return wsgi.runwsgi(self.wsgifunc(*middleware)) + + def stop(self): + """Stops the http server started by run.""" + if httpserver.server: + httpserver.server.stop() + httpserver.server = None + + def cgirun(self, *middleware): + """ + Return a CGI handler. This is mostly useful with Google App Engine. + There you can just do: + + main = app.cgirun() + """ + wsgiapp = self.wsgifunc(*middleware) + + try: + from google.appengine.ext.webapp.util import run_wsgi_app + + return run_wsgi_app(wsgiapp) + except ImportError: + # we're not running from within Google App Engine + return wsgiref.handlers.CGIHandler().run(wsgiapp) + + def gaerun(self, *middleware): + """ + Starts the program in a way that will work with Google app engine, + no matter which version you are using (2.5 / 2.7) + + If it is 2.5, just normally start it with app.gaerun() + + If it is 2.7, make sure to change the app.yaml handler to point to the + global variable that contains the result of app.gaerun() + + For example: + + in app.yaml (where code.py is where the main code is located) + + handlers: + - url: /.* + script: code.app + + Make sure that the app variable is globally accessible + """ + wsgiapp = self.wsgifunc(*middleware) + try: + # check what version of python is running + version = sys.version_info[:2] + major = version[0] + minor = version[1] + + if major != 2: + raise OSError("Google App Engine only supports python 2.5 and 2.7") + + # if 2.7, return a function that can be run by gae + if minor == 7: + return wsgiapp + # if 2.5, use run_wsgi_app + elif minor == 5: + from google.appengine.ext.webapp.util import run_wsgi_app + + return run_wsgi_app(wsgiapp) + else: + raise OSError("Not a supported platform, use python 2.5 or 2.7") + except ImportError: + return wsgiref.handlers.CGIHandler().run(wsgiapp) + + def load(self, env): + """Initializes ctx using env.""" + ctx = web.ctx + ctx.clear() + ctx.status = "200 OK" + ctx.headers = [] + ctx.output = "" + ctx.environ = ctx.env = env + ctx.host = env.get("HTTP_HOST") + + if env.get("wsgi.url_scheme") in ["http", "https"]: + ctx.protocol = env["wsgi.url_scheme"] + elif env.get("HTTPS", "").lower() in ["on", "true", "1"]: + ctx.protocol = "https" + else: + ctx.protocol = "http" + ctx.homedomain = ctx.protocol + "://" + env.get("HTTP_HOST", "[unknown]") + ctx.homepath = os.environ.get("REAL_SCRIPT_NAME", env.get("SCRIPT_NAME", "")) + ctx.home = ctx.homedomain + ctx.homepath + # @@ home is changed when the request is handled to a sub-application. + # @@ but the real home is required for doing absolute redirects. + ctx.realhome = ctx.home + ctx.ip = env.get("REMOTE_ADDR") + ctx.method = env.get("REQUEST_METHOD") + try: + ctx.path = bytes(env.get("PATH_INFO"), "latin1").decode("utf8") + except UnicodeDecodeError: # If there are Unicode characters... + ctx.path = env.get("PATH_INFO") + + # http://trac.lighttpd.net/trac/ticket/406 requires: + if env.get("SERVER_SOFTWARE", "").startswith(("lighttpd/", "nginx/")): + ctx.path = lstrips(env.get("REQUEST_URI").split("?")[0], ctx.homepath) + # Apache and CherryPy webservers unquote urls but lighttpd and nginx do not. + # Unquote explicitly for lighttpd and nginx to make ctx.path uniform across + # all servers. + ctx.path = unquote(ctx.path) + + if env.get("QUERY_STRING"): + ctx.query = "?" + env.get("QUERY_STRING", "") + else: + ctx.query = "" + + ctx.fullpath = ctx.path + ctx.query + + for k, v in iteritems(ctx): + # convert all string values to unicode values and replace + # malformed data with a suitable replacement marker. + if isinstance(v, bytes): + ctx[k] = v.decode("utf-8", "replace") + + # status must always be str + ctx.status = "200 OK" + + ctx.app_stack = [] + + def _delegate(self, f, fvars, args=[]): + def handle_class(cls): + meth = web.ctx.method + if meth == "HEAD" and not hasattr(cls, meth): + meth = "GET" + if not hasattr(cls, meth): + raise web.nomethod(cls) + tocall = getattr(cls(), meth) + return tocall(*args) + + if f is None: + raise web.notfound() + elif isinstance(f, application): + return f.handle_with_processors() + elif isclass(f): + return handle_class(f) + elif isinstance(f, str): + if f.startswith("redirect "): + url = f.split(" ", 1)[1] + if web.ctx.method == "GET": + x = web.ctx.env.get("QUERY_STRING", "") + if x: + url += "?" + x + raise web.redirect(url) + elif "." in f: + mod, cls = f.rsplit(".", 1) + mod = __import__(mod, None, None, [""]) + cls = getattr(mod, cls) + else: + cls = fvars[f] + return handle_class(cls) + elif hasattr(f, "__call__"): + return f() + else: + return web.notfound() + + def _match(self, mapping, value): + for pat, what in mapping: + if isinstance(what, application): + if value.startswith(pat): + f = lambda: self._delegate_sub_application(pat, what) + return f, None + else: + continue + elif isinstance(what, str): + what, result = utils.re_subm(rf"^{pat}\Z", what, value) + else: + result = utils.re_compile(rf"^{pat}\Z").match(value) + + if result: # it's a match + return what, [x for x in result.groups()] + return None, None + + def _delegate_sub_application(self, dir, app): + """Deletes request to sub application `app` rooted at the directory `dir`. + The home, homepath, path and fullpath values in web.ctx are updated to mimic request + to the subapp and are restored after it is handled. + + @@Any issues with when used with yield? + """ + web.ctx._oldctx = web.storage(web.ctx) + web.ctx.home += dir + web.ctx.homepath += dir + web.ctx.path = web.ctx.path[len(dir) :] + web.ctx.fullpath = web.ctx.fullpath[len(dir) :] + return app.handle_with_processors() + + def get_parent_app(self): + if self in web.ctx.app_stack: + index = web.ctx.app_stack.index(self) + if index > 0: + return web.ctx.app_stack[index - 1] + + def notfound(self): + """Returns HTTPError with '404 not found' message""" + parent = self.get_parent_app() + if parent: + return parent.notfound() + else: + return web._NotFound() + + def internalerror(self): + """Returns HTTPError with '500 internal error' message""" + parent = self.get_parent_app() + if parent: + return parent.internalerror() + elif web.config.get("debug"): + return debugerror() + else: + return web._InternalError() + + +def with_metaclass(mcls): + def decorator(cls): + body = vars(cls).copy() + # clean out class body + body.pop("__dict__", None) + body.pop("__weakref__", None) + return mcls(cls.__name__, cls.__bases__, body) + + return decorator + + +class auto_application(application): + """Application similar to `application` but urls are constructed + automatically using metaclass. + + >>> app = auto_application() + >>> class hello(app.page): + ... def GET(self): return "hello, world" + ... + >>> class foo(app.page): + ... path = '/foo/.*' + ... def GET(self): return "foo" + >>> app.request("/hello").data + 'hello, world' + >>> app.request('/foo/bar').data + 'foo' + """ + + # PY3DOCTEST: b'hello, world' + # PY3DOCTEST: b'foo' + + def __init__(self): + application.__init__(self) + + class metapage(type): + def __init__(klass, name, bases, attrs): + type.__init__(klass, name, bases, attrs) + path = attrs.get("path", "/" + name) + + # path can be specified as None to ignore that class + # typically required to create a abstract base class. + if path is not None: + self.add_mapping(path, klass) + + @with_metaclass(metapage) # little hack needed for Py2 and Py3 compatibility + class page: + path = None + + self.page = page + + +# The application class already has the required functionality of subdir_application +subdir_application = application + + +class subdomain_application(application): + r""" + Application to delegate requests based on the host. + + >>> urls = ("/hello", "hello") + >>> app = application(urls, globals()) + >>> class hello: + ... def GET(self): return "hello" + >>> + >>> mapping = (r"hello\.example\.com", app) + >>> app2 = subdomain_application(mapping) + >>> app2.request("/hello", host="hello.example.com").data + 'hello' + >>> response = app2.request("/hello", host="something.example.com") + >>> response.status + '404 Not Found' + >>> response.data + 'not found' + """ + + # PY3DOCTEST: b'hello' + # PY3DOCTEST: b'not found' + + def handle(self): + host = web.ctx.host.split(":")[0] # strip port + fn, args = self._match(self.mapping, host) + return self._delegate(fn, self.fvars, args) + + def _match(self, mapping, value): + for pat, what in mapping: + if isinstance(what, str): + what, result = utils.re_subm("^" + pat + "$", what, value) + else: + result = utils.re_compile("^" + pat + "$").match(value) + + if result: # it's a match + return what, [x for x in result.groups()] + return None, None + + +def loadhook(h): + """ + Converts a load hook into an application processor. + + >>> app = auto_application() + >>> def f(): "something done before handling request" + ... + >>> app.add_processor(loadhook(f)) + """ + + def processor(handler): + h() + return handler() + + return processor + + +def unloadhook(h): + """ + Converts an unload hook into an application processor. + + >>> app = auto_application() + >>> def f(): "something done after handling request" + ... + >>> app.add_processor(unloadhook(f)) + """ + + def processor(handler): + try: + result = handler() + except: + # run the hook even when handler raises some exception + h() + raise + + if result and hasattr(result, "__next__"): + return wrap(result) + else: + h() + return result + + def wrap(result): + def next_hook(): + try: + return next(result) + except: + # call the hook at the and of iterator + h() + raise + + result = iter(result) + while True: + try: + yield next_hook() + except StopIteration: + return + + return processor + + +def autodelegate(prefix=""): + """ + Returns a method that takes one argument and calls the method named prefix+arg, + calling `notfound()` if there isn't one. Example: + + urls = ('/prefs/(.*)', 'prefs') + + class prefs: + GET = autodelegate('GET_') + def GET_password(self): pass + def GET_privacy(self): pass + + `GET_password` would get called for `/prefs/password` while `GET_privacy` for + `GET_privacy` gets called for `/prefs/privacy`. + + If a user visits `/prefs/password/change` then `GET_password(self, '/change')` + is called. + """ + + def internal(self, arg): + if "/" in arg: + first, rest = arg.split("/", 1) + func = prefix + first + args = ["/" + rest] + else: + func = prefix + arg + args = [] + + if hasattr(self, func): + try: + return getattr(self, func)(*args) + except TypeError: + raise web.notfound() + else: + raise web.notfound() + + return internal + + +class Reloader: + """Checks to see if any loaded modules have changed on disk and, + if so, reloads them. + """ + + """File suffix of compiled modules.""" + if sys.platform.startswith("java"): + SUFFIX = "$py.class" + else: + SUFFIX = ".pyc" + + def __init__(self): + self.mtimes = {} + + def __call__(self): + sys_modules = list(sys.modules.values()) + for mod in sys_modules: + self.check(mod) + + def check(self, mod): + # jython registers java packages as modules but they either + # don't have a __file__ attribute or its value is None + if not (mod and hasattr(mod, "__file__") and mod.__file__): + return + + try: + mtime = os.stat(mod.__file__).st_mtime + except OSError: + return + if mod.__file__.endswith(self.__class__.SUFFIX) and os.path.exists( + mod.__file__[:-1] + ): + mtime = max(os.stat(mod.__file__[:-1]).st_mtime, mtime) + + if mod not in self.mtimes: + self.mtimes[mod] = mtime + elif self.mtimes[mod] < mtime: + try: + reload(mod) + self.mtimes[mod] = mtime + except ImportError: + pass + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/web/browser.py b/web/browser.py new file mode 100644 index 0000000..d3bb601 --- /dev/null +++ b/web/browser.py @@ -0,0 +1,295 @@ +"""Browser to test web applications. +(from web.py) +""" +import os +import webbrowser +from http.cookiejar import CookieJar +from io import BytesIO +from urllib.parse import urljoin +from urllib.request import HTTPCookieProcessor, HTTPError, HTTPHandler, Request +from urllib.request import build_opener as urllib_build_opener +from urllib.response import addinfourl + +from .net import htmlunquote +from .utils import re_compile + +DEBUG = False + +__all__ = ["BrowserError", "Browser", "AppBrowser", "AppHandler"] + + +class BrowserError(Exception): + pass + + +class Browser: + def __init__(self): + self.cookiejar = CookieJar() + self._cookie_processor = HTTPCookieProcessor(self.cookiejar) + self.form = None + + self.url = "http://0.0.0.0:8080/" + self.path = "/" + + self.status = None + self.data = None + self._response = None + self._forms = None + + @property + def text(self): + return self.data.decode("utf-8") + + def reset(self): + """Clears all cookies and history.""" + self.cookiejar.clear() + + def build_opener(self): + """Builds the opener using (urllib2/urllib.request).build_opener. + Subclasses can override this function to prodive custom openers. + """ + return urllib_build_opener() + + def do_request(self, req): + if DEBUG: + print("requesting", req.get_method(), req.get_full_url()) + + opener = self.build_opener() + opener.add_handler(self._cookie_processor) + try: + self._response = opener.open(req) + except HTTPError as e: + self._response = e + + self.url = self._response.geturl() + self.path = Request(self.url).selector + self.data = self._response.read() + self.status = self._response.code + self._forms = None + self.form = None + + return self.get_response() + + def open(self, url, data=None, headers={}): + """Opens the specified url.""" + url = urljoin(self.url, url) + req = Request(url, data, headers) + + return self.do_request(req) + + def show(self): + """Opens the current page in real web browser.""" + f = open("page.html", "w") + f.write(self.data) + f.close() + + url = "file://" + os.path.abspath("page.html") + webbrowser.open(url) + + def get_response(self): + """Returns a copy of the current response.""" + return addinfourl( + BytesIO(self.data), self._response.info(), self._response.geturl() + ) + + def get_soup(self): + """Returns beautiful soup of the current document.""" + import BeautifulSoup + + return BeautifulSoup.BeautifulSoup(self.data) + + def get_text(self, e=None): + """Returns content of e or the current document as plain text.""" + e = e or self.get_soup() + return "".join( + [htmlunquote(c) for c in e.recursiveChildGenerator() if isinstance(c, str)] + ) + + def _get_links(self): + soup = self.get_soup() + return [a for a in soup.findAll(name="a")] + + def get_links( + self, text=None, text_regex=None, url=None, url_regex=None, predicate=None + ): + """Returns all links in the document.""" + return self._filter_links( + self._get_links(), + text=text, + text_regex=text_regex, + url=url, + url_regex=url_regex, + predicate=predicate, + ) + + def follow_link( + self, + link=None, + text=None, + text_regex=None, + url=None, + url_regex=None, + predicate=None, + ): + if link is None: + links = self._filter_links( + self.get_links(), + text=text, + text_regex=text_regex, + url=url, + url_regex=url_regex, + predicate=predicate, + ) + link = links and links[0] + + if link: + return self.open(link["href"]) + else: + raise BrowserError("No link found") + + def find_link( + self, text=None, text_regex=None, url=None, url_regex=None, predicate=None + ): + links = self._filter_links( + self.get_links(), + text=text, + text_regex=text_regex, + url=url, + url_regex=url_regex, + predicate=predicate, + ) + return links and links[0] or None + + def _filter_links( + self, + links, + text=None, + text_regex=None, + url=None, + url_regex=None, + predicate=None, + ): + predicates = [] + if text is not None: + predicates.append(lambda link: link.string == text) + if text_regex is not None: + predicates.append( + lambda link: re_compile(text_regex).search(link.string or "") + ) + if url is not None: + predicates.append(lambda link: link.get("href") == url) + if url_regex is not None: + predicates.append( + lambda link: re_compile(url_regex).search(link.get("href", "")) + ) + if predicate: + predicate.append(predicate) + + def f(link): + for p in predicates: + if not p(link): + return False + return True + + return [link for link in links if f(link)] + + def get_forms(self): + """Returns all forms in the current document. + The returned form objects implement the ClientForm.HTMLForm interface. + """ + if self._forms is None: + import ClientForm + + self._forms = ClientForm.ParseResponse( + self.get_response(), backwards_compat=False + ) + return self._forms + + def select_form(self, name=None, predicate=None, index=0): + """Selects the specified form.""" + forms = self.get_forms() + + if name is not None: + forms = [f for f in forms if f.name == name] + if predicate: + forms = [f for f in forms if predicate(f)] + + if forms: + self.form = forms[index] + return self.form + else: + raise BrowserError("No form selected.") + + def submit(self, **kw): + """submits the currently selected form.""" + if self.form is None: + raise BrowserError("No form selected.") + req = self.form.click(**kw) + return self.do_request(req) + + def __getitem__(self, key): + return self.form[key] + + def __setitem__(self, key, value): + self.form[key] = value + + +class AppBrowser(Browser): + """Browser interface to test web.py apps. + + b = AppBrowser(app) + b.open('/') + b.follow_link(text='Login') + + b.select_form(name='login') + b['username'] = 'joe' + b['password'] = 'secret' + b.submit() + + assert b.path == '/' + assert 'Welcome joe' in b.get_text() + """ + + def __init__(self, app): + Browser.__init__(self) + self.app = app + + def build_opener(self): + return urllib_build_opener(AppHandler(self.app)) + + +class AppHandler(HTTPHandler): + """urllib2 handler to handle requests using web.py application.""" + + handler_order = 100 + https_request = HTTPHandler.do_request_ + + def __init__(self, app): + self.app = app + + def http_open(self, req): + result = self.app.request( + localpart=req.selector, + method=req.get_method(), + host=req.host, + data=req.data, + headers=dict(req.header_items()), + https=(req.type == "https"), + ) + return self._make_response(result, req.get_full_url()) + + def https_open(self, req): + return self.http_open(req) + + def _make_response(self, result, url): + + data = "\r\n".join([f"{k}: {v}" for k, v in result.header_items]) + + import email + + headers = email.message_from_string(data) + + response = addinfourl(BytesIO(result.data), headers, url) + code, msg = result.status.split(None, 1) + response.code, response.msg = int(code), msg + return response diff --git a/web/contrib/__init__.py b/web/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/web/contrib/template.py b/web/contrib/template.py new file mode 100644 index 0000000..20cb4ed --- /dev/null +++ b/web/contrib/template.py @@ -0,0 +1,146 @@ +""" +Interface to various templating engines. +""" +import os.path + +__all__ = ["render_cheetah", "render_genshi", "render_mako", "cache"] + + +class render_cheetah: + """Rendering interface to Cheetah Templates. + + Example: + + render = render_cheetah('templates') + render.hello(name="cheetah") + """ + + def __init__(self, path): + # give error if Chetah is not installed + from Cheetah.Template import Template # noqa: F401 + + self.path = path + + def __getattr__(self, name): + from Cheetah.Template import Template + + path = os.path.join(self.path, name + ".html") + + def template(**kw): + t = Template(file=path, searchList=[kw]) + return t.respond() + + return template + + +class render_genshi: + """Rendering interface genshi templates. + Example: + + for xml/html templates. + + render = render_genshi(['templates/']) + render.hello(name='genshi') + + For text templates: + + render = render_genshi(['templates/'], type='text') + render.hello(name='genshi') + """ + + def __init__(self, *a, **kwargs): + from genshi.template import TemplateLoader + + self._type = kwargs.pop("type", None) + self._loader = TemplateLoader(*a, **kwargs) + + def __getattr__(self, name): + # Assuming all templates are html + path = name + ".html" + + if self._type == "text": + from genshi.template import TextTemplate + + cls = TextTemplate + type = "text" + else: + cls = None + type = self._type + + t = self._loader.load(path, cls=cls) + + def template(**kw): + stream = t.generate(**kw) + if type: + return stream.render(type) + else: + return stream.render() + + return template + + +class render_jinja: + """Rendering interface to Jinja2 Templates + + Example: + + render= render_jinja('templates') + render.hello(name='jinja2') + """ + + def __init__(self, *a, **kwargs): + extensions = kwargs.pop("extensions", []) + globals = kwargs.pop("globals", {}) + + from jinja2 import Environment, FileSystemLoader + + self._lookup = Environment( + loader=FileSystemLoader(*a, **kwargs), extensions=extensions + ) + self._lookup.globals.update(globals) + + def __getattr__(self, name): + # Assuming all templates end with .html + path = name + ".html" + t = self._lookup.get_template(path) + return t.render + + +class render_mako: + """Rendering interface to Mako Templates. + + Example: + + render = render_mako(directories=['templates']) + render.hello(name="mako") + """ + + def __init__(self, *a, **kwargs): + from mako.lookup import TemplateLookup + + self._lookup = TemplateLookup(*a, **kwargs) + + def __getattr__(self, name): + # Assuming all templates are html + path = name + ".html" + t = self._lookup.get_template(path) + return t.render + + +class cache: + """Cache for any rendering interface. + + Example: + + render = cache(render_cheetah("templates/")) + render.hello(name='cache') + """ + + def __init__(self, render): + self._render = render + self._cache = {} + + def __getattr__(self, name): + if name not in self._cache: + self._cache[name] = getattr(self._render, name) + return self._cache[name] diff --git a/web/db.py b/web/db.py new file mode 100644 index 0000000..e9008f3 --- /dev/null +++ b/web/db.py @@ -0,0 +1,1742 @@ +""" +Database API +(part of web.py) +""" +import ast +import datetime +import os +import re +import time +from urllib import parse as urlparse +from urllib.parse import unquote + +from .py3helpers import iteritems +from .utils import iters, safestr, safeunicode, storage, threadeddict + +try: + # db module can work independent of web.py + from .webapi import config, debug +except ImportError: + import sys + + debug = sys.stderr + config = storage() + +__all__ = [ + "UnknownParamstyle", + "UnknownDB", + "TransactionError", + "sqllist", + "sqlors", + "reparam", + "sqlquote", + "SQLQuery", + "SQLParam", + "sqlparam", + "SQLLiteral", + "sqlliteral", + "database", + "DB", +] + +TOKEN = "[ \\f\\t]*(\\\\\\r?\\n[ \\f\\t]*)*(#[^\\r\\n]*)?(((\\d+[jJ]|((\\d+\\.\\d*|\\.\\d+)([eE][-+]?\\d+)?|\\d+[eE][-+]?\\d+)[jJ])|((\\d+\\.\\d*|\\.\\d+)([eE][-+]?\\d+)?|\\d+[eE][-+]?\\d+)|(0[xX][\\da-fA-F]+[lL]?|0[bB][01]+[lL]?|(0[oO][0-7]+)|(0[0-7]*)[lL]?|[1-9]\\d*[lL]?))|((\\*\\*=?|>>=?|<<=?|<>|!=|//=?|[+\\-*/%&|^=<>]=?|~)|[][(){}]|(\\r?\\n|[:;.,`@]))|([uUbB]?[rR]?'[^\\n'\\\\]*(?:\\\\.[^\\n'\\\\]*)*'|[uUbB]?[rR]?\"[^\\n\"\\\\]*(?:\\\\.[^\\n\"\\\\]*)*\")|[a-zA-Z_]\\w*)" + +tokenprog = re.compile(TOKEN) + +# Supported db drivers. +pg_drivers = ("psycopg2",) +mysql_drivers = ("pymysql", "MySQLdb", "mysql.connector") +sqlite_drivers = ("sqlite3", "pysqlite2.dbapi2", "sqlite") + + +class UnknownDB(Exception): + """raised for unsupported dbms""" + + pass + + +class _ItplError(ValueError): + def __init__(self, text, pos): + ValueError.__init__(self) + self.text = text + self.pos = pos + + def __str__(self): + return "unfinished expression in %s at char %d" % (repr(self.text), self.pos) + + +class TransactionError(Exception): + pass + + +class UnknownParamstyle(Exception): + """ + raised for unsupported db paramstyles + + (currently supported: qmark, numeric, format, pyformat) + """ + + pass + + +class SQLParam: + """ + Parameter in SQLQuery. + + >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")]) + >>> q + + >>> q.query() + 'SELECT * FROM test WHERE name=%s' + >>> q.values() + ['joe'] + """ + + __slots__ = ["value"] + + def __init__(self, value): + self.value = value + + def get_marker(self, paramstyle="pyformat"): + if paramstyle == "qmark": + return "?" + elif paramstyle == "numeric": + return ":1" + elif paramstyle is None or paramstyle in ["format", "pyformat"]: + return "%s" + raise UnknownParamstyle(paramstyle) + + def sqlquery(self): + return SQLQuery([self]) + + def __add__(self, other): + return self.sqlquery() + other + + def __radd__(self, other): + return other + self.sqlquery() + + def __str__(self): + return str(self.value) + + def __eq__(self, other): + return isinstance(other, SQLParam) and other.value == self.value + + def __repr__(self): + return "" % repr(self.value) + + +sqlparam = SQLParam + + +class SQLQuery: + """ + You can pass this sort of thing as a clause in any db function. + Otherwise, you can pass a dictionary to the keyword argument `vars` + and the function will call reparam for you. + + Internally, consists of `items`, which is a list of strings and + SQLParams, which get concatenated to produce the actual query. + """ + + __slots__ = ["items"] + + # tested in sqlquote's docstring + def __init__(self, items=None): + r"""Creates a new SQLQuery. + + >>> SQLQuery("x") + + >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)]) + >>> q + + >>> q.query(), q.values() + ('SELECT * FROM test WHERE x=%s', [1]) + >>> SQLQuery(SQLParam(1)) + + """ + if items is None: + self.items = [] + elif isinstance(items, list): + self.items = items + elif isinstance(items, SQLParam): + self.items = [items] + elif isinstance(items, SQLQuery): + self.items = list(items.items) + else: + self.items = [items] + + # Take care of SQLLiterals + for i, item in enumerate(self.items): + if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral): + self.items[i] = item.value.v + + def append(self, value): + self.items.append(value) + + def __add__(self, other): + if isinstance(other, str): + items = [other] + elif isinstance(other, SQLQuery): + items = other.items + else: + return NotImplemented + return SQLQuery(self.items + items) + + def __radd__(self, other): + if isinstance(other, str): + items = [other] + elif isinstance(other, SQLQuery): + items = other.items + else: + return NotImplemented + return SQLQuery(items + self.items) + + def __iadd__(self, other): + if isinstance(other, (str, SQLParam)): + self.items.append(other) + elif isinstance(other, SQLQuery): + self.items.extend(other.items) + else: + return NotImplemented + return self + + def __len__(self): + return len(self.query()) + + def __eq__(self, other): + return isinstance(other, SQLQuery) and other.items == self.items + + def query(self, paramstyle=None): + """ + Returns the query part of the sql query. + >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')]) + >>> q.query() + 'SELECT * FROM test WHERE name=%s' + >>> q.query(paramstyle='qmark') + 'SELECT * FROM test WHERE name=?' + """ + s = [] + for x in self.items: + if isinstance(x, SQLParam): + x = x.get_marker(paramstyle) + s.append(safestr(x)) + else: + x = safestr(x) + # automatically escape % characters in the query + # For backward compatibility, ignore escaping when the query + # looks already escaped + if paramstyle in ["format", "pyformat"]: + if "%" in x and "%%" not in x: + x = x.replace("%", "%%") + s.append(x) + return "".join(s) + + def values(self): + """ + Returns the values of the parameters used in the sql query. + >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')]) + >>> q.values() + ['joe'] + """ + return [i.value for i in self.items if isinstance(i, SQLParam)] + + def join(items, sep=" ", prefix=None, suffix=None, target=None): + """ + Joins multiple queries. + + >>> SQLQuery.join(['a', 'b'], ', ') + + + Optionally, prefix and suffix arguments can be provided. + + >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')') + + + If target argument is provided, the items are appended to target + instead of creating a new SQLQuery. + """ + if target is None: + target = SQLQuery() + + target_items = target.items + + if prefix: + target_items.append(prefix) + + for i, item in enumerate(items): + if i != 0 and sep != "": + target_items.append(sep) + if isinstance(item, SQLQuery): + target_items.extend(item.items) + elif item == "": # joins with empty strings + continue + else: + target_items.append(item) + + if suffix: + target_items.append(suffix) + return target + + join = staticmethod(join) + + def _str(self): + try: + return self.query() % tuple(sqlify(x) for x in self.values()) + except (ValueError, TypeError): + return self.query() + + def __str__(self): + return safestr(self._str()) + + def __unicode__(self): + return safeunicode(self._str()) + + def __repr__(self): + return "" % repr(str(self)) + + +class SQLLiteral: + """ + Protects a string from `sqlquote`. + + >>> sqlquote('NOW()') + + >>> sqlquote(SQLLiteral('NOW()')) + + """ + + def __init__(self, v): + self.v = v + + def __repr__(self): + return "" % self.v + + +sqlliteral = SQLLiteral + + +def _sqllist(values): + """ + >>> _sqllist([1, 2, 3]) + + >>> _sqllist(set([5, 1, 3, 2])) + + >>> _sqllist((5, 1, 3, 2, 2, 5)) + + """ + items = [] + items.append("(") + + if isinstance(values, set): + values = list(values) + elif isinstance(values, tuple): + values = list(set(values)) + + for i, v in enumerate(values): + if i != 0: + items.append(", ") + items.append(sqlparam(v)) + items.append(")") + return SQLQuery(items) + + +def reparam(string_, dictionary): + """ + Takes a string and a dictionary and interpolates the string + using values from the dictionary. Returns an `SQLQuery` for the result. + + >>> reparam("s = $s", dict(s=True)) + + >>> reparam("s IN $s", dict(s=[1, 2])) + + """ + return SafeEval().safeeval(string_, dictionary) + + +def sqlify(obj): + """ + converts `obj` to its proper SQL version + + >>> sqlify(None) + 'NULL' + >>> sqlify(True) + "'t'" + >>> sqlify(3) + '3' + """ + # because `1 == True and hash(1) == hash(True)` + # we have to do this the hard way... + + if obj is None: + return "NULL" + elif obj is True: + return "'t'" + elif obj is False: + return "'f'" + elif isinstance(obj, int): + return str(obj) + elif isinstance(obj, datetime.datetime): + return repr(obj.isoformat()) + else: + return repr(obj) + + +def sqllist(lst): + """ + Converts the arguments for use in something like a WHERE clause. + + >>> sqllist(['a', 'b']) + 'a, b' + >>> sqllist('a') + 'a' + """ + if isinstance(lst, str): + return lst + else: + return ", ".join(lst) + + +def sqlors(left, lst): + """ + `left is a SQL clause like `tablename.arg = ` + and `lst` is a list of values. Returns a reparam-style + pair featuring the SQL that ORs together the clause + for each item in the lst. + + >>> sqlors('foo = ', []) + + >>> sqlors('foo = ', [1]) + + >>> sqlors('foo = ', 1) + + >>> sqlors('foo = ', [1,2,3]) + + """ + if isinstance(lst, iters): + lst = list(lst) + ln = len(lst) + if ln == 0: + return SQLQuery("1=2") + if ln == 1: + lst = lst[0] + + if isinstance(lst, iters): + return SQLQuery( + ["("] + sum(([left, sqlparam(x), " OR "] for x in lst), []) + ["1=2)"] + ) + else: + return left + sqlparam(lst) + + +def sqlwhere(data, grouping=" AND "): + """ + Converts a two-tuple (key, value) iterable `data` to an SQL WHERE clause + `SQLQuery`. + + >>> sqlwhere((('cust_id', 2), ('order_id',3))) + + >>> sqlwhere((('order_id', 3), ('cust_id', 2)), grouping=', ') + + >>> sqlwhere((('a', 'a'), ('b', 'b'))).query() + 'a = %s AND b = %s' + """ + + return SQLQuery.join([k + " = " + sqlparam(v) for k, v in data], grouping) + + +def sqlquote(a): + """ + Ensures `a` is quoted properly for use in a SQL query. + + >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3) + + >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3]) + + >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote(set([3, 2, 3, 4])) + + >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote((3, 2, 3, 4)) + + """ + if isinstance(a, (list, tuple, set)): + return _sqllist(a) + else: + return sqlparam(a).sqlquery() + + +class BaseResultSet: + """Base implementation of Result Set, the result of a db query.""" + + def __init__(self, cursor): + self.cursor = cursor + self.names = [x[0] for x in cursor.description] + self._index = 0 + + def list(self): + rows = [self._prepare_row(d) for d in self.cursor.fetchall()] + self._index += len(rows) + return rows + + def _prepare_row(self, row): + return storage(dict(zip(self.names, row))) + + def __iter__(self): + return self + + def __next__(self): + row = self.cursor.fetchone() + if row is None: + raise StopIteration() + self._index += 1 + return self._prepare_row(row) + + next = __next__ # for python 2.7 support + + def first(self, default=None): + """Returns the first row of this ResultSet or None when there are no + elements. + + If the optional argument default is specified, that is returned instead + of None when there are no elements. + """ + try: + return next(iter(self)) + except StopIteration: + return default + + def __getitem__(self, i): + # todo: slices + if i < self._index: + raise IndexError("already passed " + str(i)) + try: + while i > self._index: + next(self) + self._index += 1 + # now self._index == i + self._index += 1 + return next(self) + except StopIteration: + raise IndexError(str(i)) + + +class ResultSet(BaseResultSet): + """The result of a database query.""" + + def __len__(self): + return int(self.cursor.rowcount) + + +class SqliteResultSet(BaseResultSet): + """Result Set for sqlite. + + Same functionally as ResultSet except len is not supported. + """ + + def __init__(self, cursor): + BaseResultSet.__init__(self, cursor) + self._head = None + + def __next__(self): + if self._head is not None: + self._index += 1 + return self._head + else: + return super().__next__() + + def __bool__(self): + # The ResultSet class class doesn't need to support __bool__ explicitly + # because it has __len__. Since SqliteResultSet doesn't support len, + # we need to peep into the result to find if the result is empty of not. + if self._head is None: + try: + self._head = next(self) + self._index -= 1 # reset the index + except StopIteration: + return False + return True + + +class Transaction: + """Database transaction.""" + + def __init__(self, ctx): + self.ctx = ctx + self.transaction_count = transaction_count = len(ctx.transactions) + + class transaction_engine: + """Transaction Engine used in top level transactions.""" + + def do_transact(self): + ctx.commit(unload=False) + + def do_commit(self): + ctx.commit() + + def do_rollback(self): + ctx.rollback() + + class subtransaction_engine: + """Transaction Engine used in sub transactions.""" + + def query(self, q): + db_cursor = ctx.db.cursor() + ctx.db_execute(db_cursor, SQLQuery(q % transaction_count)) + + def do_transact(self): + self.query("SAVEPOINT webpy_sp_%s") + + def do_commit(self): + self.query("RELEASE SAVEPOINT webpy_sp_%s") + + def do_rollback(self): + self.query("ROLLBACK TO SAVEPOINT webpy_sp_%s") + + class dummy_engine: + """Transaction Engine used instead of subtransaction_engine + when sub transactions are not supported.""" + + do_transact = do_commit = do_rollback = lambda self: None + + if self.transaction_count: + # nested transactions are not supported in some databases + if self.ctx.get("ignore_nested_transactions"): + self.engine = dummy_engine() + else: + self.engine = subtransaction_engine() + else: + self.engine = transaction_engine() + + self.engine.do_transact() + self.ctx.transactions.append(self) + + def __enter__(self): + return self + + def __exit__(self, exctype, excvalue, traceback): + if exctype is not None: + self.rollback() + else: + self.commit() + + def commit(self): + if len(self.ctx.transactions) > self.transaction_count: + self.engine.do_commit() + self.ctx.transactions = self.ctx.transactions[: self.transaction_count] + + def rollback(self): + if len(self.ctx.transactions) > self.transaction_count: + self.engine.do_rollback() + self.ctx.transactions = self.ctx.transactions[: self.transaction_count] + + +class DB: + """Database""" + + def __init__(self, db_module, keywords): + """Creates a database.""" + # some DB implementations take optional parameter `driver` to use a + # specific driver module but it should not be passed to `connect`. + keywords.pop("driver", None) + + self.db_module = db_module + self.keywords = keywords + + self._ctx = threadeddict() + # flag to enable/disable printing queries + self.printing = config.get("debug_sql", config.get("debug", False)) + self.supports_multiple_insert = False + + try: + import dbutils # noqa: F401 + + # enable pooling if DBUtils module is available. + self.has_pooling = True + except ImportError: + self.has_pooling = False + + # Pooling can be disabled by passing pooling=False in the keywords. + self.has_pooling = self.keywords.pop("pooling", True) and self.has_pooling + + def _getctx(self): + if not self._ctx.get("db"): + self._load_context(self._ctx) + return self._ctx + + ctx = property(_getctx) + + def _load_context(self, ctx): + ctx.dbq_count = 0 + ctx.transactions = [] # stack of transactions + + if self.has_pooling: + ctx.db = self._connect_with_pooling(self.keywords) + else: + ctx.db = self._connect(self.keywords) + ctx.db_execute = self._db_execute + + if not hasattr(ctx.db, "commit"): + ctx.db.commit = lambda: None + + if not hasattr(ctx.db, "rollback"): + ctx.db.rollback = lambda: None + + def commit(unload=True): + # do db commit and release the connection if pooling is enabled. + ctx.db.commit() + if unload and self.has_pooling: + self._unload_context(self._ctx) + + def rollback(): + # do db rollback and release the connection if pooling is enabled. + ctx.db.rollback() + if self.has_pooling: + self._unload_context(self._ctx) + + ctx.commit = commit + ctx.rollback = rollback + + def _unload_context(self, ctx): + del ctx.db + + def _connect(self, keywords): + return self.db_module.connect(**keywords) + + def _connect_with_pooling(self, keywords): + def get_pooled_db(): + # In DBUtils 2.0.0, names were made pep8 compliant + # https://webwareforpython.github.io/DBUtils/changelog.html + from dbutils import pooled_db as PooledDB + + # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator` + # see Bug#122112 + + if PooledDB.__version__.split(".") < "0.9.3".split("."): + return PooledDB.PooledDB(dbapi=self.db_module, **keywords) + else: + return PooledDB.PooledDB(creator=self.db_module, **keywords) + + if getattr(self, "_pooleddb", None) is None: + self._pooleddb = get_pooled_db() + + return self._pooleddb.connection() + + def _db_cursor(self): + return self.ctx.db.cursor() + + def _param_marker(self): + """Returns parameter marker based on paramstyle attribute if this database.""" + style = getattr(self, "paramstyle", "pyformat") + + if style == "qmark": + return "?" + elif style == "numeric": + return ":1" + elif style in ["format", "pyformat"]: + return "%s" + raise UnknownParamstyle(style) + + def _db_execute(self, cur, sql_query): + """executes an sql query""" + self.ctx.dbq_count += 1 + + try: + a = time.time() + query, params = self._process_query(sql_query) + out = cur.execute(query, params) + b = time.time() + except: + if self.printing: + print("ERR:", str(sql_query), file=debug) + if self.ctx.transactions: + self.ctx.transactions[-1].rollback() + else: + self.ctx.rollback() + raise + + if self.printing: + print( + f"{round(b - a, 2)} ({self.ctx.dbq_count}): {str(sql_query)}", + file=debug, + ) + return out + + def _process_query(self, sql_query): + """Takes the SQLQuery object and returns query string and parameters.""" + paramstyle = getattr(self, "paramstyle", "pyformat") + query = sql_query.query(paramstyle) + params = sql_query.values() + return query, params + + def _where(self, where, vars): + if isinstance(where, int): + where = "id = " + sqlparam(where) + # @@@ for backward-compatibility + elif isinstance(where, (list, tuple)) and len(where) == 2: + where = SQLQuery(where[0], where[1]) + elif isinstance(where, dict): + where = self._where_dict(where) + elif isinstance(where, SQLQuery): + pass + else: + where = reparam(where, vars) + return where + + def _where_dict(self, where): + where_clauses = [] + + for k, v in sorted(iteritems(where), key=lambda t: t[0]): + where_clauses.append(k + " = " + sqlquote(v)) + if where_clauses: + return SQLQuery.join(where_clauses, " AND ") + else: + return None + + def query(self, sql_query, vars=None, processed=False, _test=False): + """ + Execute SQL query `sql_query` using dictionary `vars` to interpolate it. + If `processed=True`, `vars` is a `reparam`-style list to use + instead of interpolating. + + >>> db = DB(None, {}) + >>> db.query("SELECT * FROM foo", _test=True) + + >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True) + + >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True) + + """ + if vars is None: + vars = {} + + if not processed and not isinstance(sql_query, SQLQuery): + sql_query = reparam(sql_query, vars) + + if _test: + return sql_query + + db_cursor = self._db_cursor() + self._db_execute(db_cursor, sql_query) + + if db_cursor.description: + out = self.create_result_set(db_cursor) + else: + out = db_cursor.rowcount + + if not self.ctx.transactions: + self.ctx.commit() + return out + + def create_result_set(self, cursor): + return ResultSet(cursor) + + def select( + self, + tables, + vars=None, + what="*", + where=None, + order=None, + group=None, + limit=None, + offset=None, + _test=False, + ): + """ + Selects `what` from `tables` with clauses `where`, `order`, + `group`, `limit`, and `offset`. Uses vars to interpolate. + Otherwise, each clause can be a SQLQuery. + + >>> db = DB(None, {}) + >>> db.select('foo', _test=True) + + >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True) + + >>> db.select('foo', where={'id': 5}, _test=True) + + """ + if vars is None: + vars = {} + + sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset) + clauses = [ + self.gen_clause(sql, val, vars) + for sql, val in sql_clauses + if val is not None + ] + qout = SQLQuery.join(clauses) + + if _test: + return qout + + return self.query(qout, processed=True) + + def where( + self, + table, + what="*", + order=None, + group=None, + limit=None, + offset=None, + _test=False, + **kwargs, + ): + """ + Selects from `table` where keys are equal to values in `kwargs`. + + >>> db = DB(None, {}) + >>> db.where('foo', bar_id=3, _test=True) + + >>> db.where('foo', source=2, crust='dewey', _test=True) + + >>> db.where('foo', _test=True) + + """ + where = self._where_dict(kwargs) + return self.select( + table, + what=what, + order=order, + group=group, + limit=limit, + offset=offset, + _test=_test, + where=where, + ) + + def sql_clauses(self, what, tables, where, group, order, limit, offset): + return ( + ("SELECT", what), + ("FROM", sqllist(tables)), + ("WHERE", where), + ("GROUP BY", group), + ("ORDER BY", order), + # The limit and offset could be the values provided by + # the end-user and are potentially unsafe. + # Using them as parameters to avoid any risk. + ("LIMIT", limit and SQLParam(limit).sqlquery()), + ("OFFSET", offset and SQLParam(offset).sqlquery()), + ) + + def gen_clause(self, sql, val, vars): + if isinstance(val, int): + if sql == "WHERE": + nout = "id = " + sqlquote(val) + else: + nout = SQLQuery(val) + # @@@ + elif isinstance(val, (list, tuple)) and len(val) == 2: + nout = SQLQuery(val[0], val[1]) # backwards-compatibility + elif sql == "WHERE" and isinstance(val, dict): + nout = self._where_dict(val) + elif isinstance(val, SQLQuery): + nout = val + else: + nout = reparam(val, vars) + + def xjoin(a, b): + if a and b: + return a + " " + b + else: + return a or b + + return xjoin(sql, nout) + + def insert(self, tablename, seqname=None, _test=False, **values): + """ + Inserts `values` into `tablename`. Returns current sequence ID. + Set `seqname` to the ID if it's not the default, or to `False` + if there isn't one. + + >>> db = DB(None, {}) + >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True) + >>> q + + >>> q.query() + 'INSERT INTO foo (age, created, name) VALUES (%s, NOW(), %s)' + >>> q.values() + [2, 'bob'] + """ + + def q(x): + return "(" + x + ")" + + if values: + # needed for Py3 compatibility with the above doctests + sorted_values = sorted(values.items(), key=lambda t: t[0]) + + _keys = SQLQuery.join(map(lambda t: t[0], sorted_values), ", ") + _values = SQLQuery.join( + [sqlparam(v) for v in map(lambda t: t[1], sorted_values)], ", " + ) + sql_query = ( + "INSERT INTO %s " % tablename + q(_keys) + " VALUES " + q(_values) + ) + else: + sql_query = SQLQuery(self._get_insert_default_values_query(tablename)) + + if _test: + return sql_query + + db_cursor = self._db_cursor() + if seqname is not False: + sql_query = self._process_insert_query(sql_query, tablename, seqname) + + if isinstance(sql_query, tuple): + # for some databases, a separate query has to be made to find + # the id of the inserted row. + q1, q2 = sql_query + self._db_execute(db_cursor, q1) + self._db_execute(db_cursor, q2) + else: + self._db_execute(db_cursor, sql_query) + + try: + out = db_cursor.fetchone()[0] + except Exception: + out = None + + if not self.ctx.transactions: + self.ctx.commit() + + return out + + def _get_insert_default_values_query(self, table): + return "INSERT INTO %s DEFAULT VALUES" % table + + def multiple_insert(self, tablename, values, seqname=None, _test=False): + """ + Inserts multiple rows into `tablename`. The `values` must be a list of + dictionaries, one for each row to be inserted, each with the same set + of keys. Returns the list of ids of the inserted rows. + Set `seqname` to the ID if it's not the default, or to `False` + if there isn't one. + + >>> db = DB(None, {}) + >>> db.supports_multiple_insert = True + >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}] + >>> db.multiple_insert('person', values=values, _test=True) + + """ + if not values: + return [] + + if not self.supports_multiple_insert: + out = [ + self.insert(tablename, seqname=seqname, _test=_test, **v) + for v in values + ] + if seqname is False: + return None + else: + return out + + keys = values[0].keys() + # @@ make sure all keys are valid + + for v in values: + if v.keys() != keys: + raise ValueError("Not all rows have the same keys") + + # enforce query order for the above doctest compatibility with Py3 + keys = sorted(keys) + + sql_query = SQLQuery( + "INSERT INTO {} ({}) VALUES ".format(tablename, ", ".join(keys)) + ) + + for i, row in enumerate(values): + if i != 0: + sql_query.append(", ") + SQLQuery.join( + [SQLParam(row[k]) for k in keys], + sep=", ", + target=sql_query, + prefix="(", + suffix=")", + ) + + if _test: + return sql_query + + db_cursor = self._db_cursor() + if seqname is not False: + sql_query = self._process_insert_query(sql_query, tablename, seqname) + + if isinstance(sql_query, tuple): + # for some databases, a separate query has to be made to find + # the id of the inserted row. + q1, q2 = sql_query + self._db_execute(db_cursor, q1) + self._db_execute(db_cursor, q2) + else: + self._db_execute(db_cursor, sql_query) + + try: + out = db_cursor.fetchone()[0] + + # MySQL gives the first id of multiple inserted rows. + # PostgreSQL and SQLite give the last id. + if self.db_module.__name__ in mysql_drivers: + out = range(out, out + len(values)) + else: + out = range(out - len(values) + 1, out + 1) + except Exception: + out = None + + if not self.ctx.transactions: + self.ctx.commit() + return out + + def update(self, tables, where, vars=None, _test=False, **values): + """ + Update `tables` with clause `where` (interpolated using `vars`) + and setting `values`. + + >>> db = DB(None, {}) + >>> name = 'Joseph' + >>> q = db.update('foo', where='name = $name', name='bob', age=2, + ... created=SQLLiteral('NOW()'), vars=locals(), _test=True) + >>> q + + >>> q.query() + 'UPDATE foo SET age = %s, created = NOW(), name = %s WHERE name = %s' + >>> q.values() + [2, 'bob', 'Joseph'] + """ + if vars is None: + vars = {} + + where = self._where(where, vars) + values = sorted(values.items(), key=lambda t: t[0]) + + query = ( + "UPDATE " + + sqllist(tables) + + " SET " + + sqlwhere(values, ", ") + + " WHERE " + + where + ) + + if _test: + return query + + db_cursor = self._db_cursor() + self._db_execute(db_cursor, query) + if not self.ctx.transactions: + self.ctx.commit() + return db_cursor.rowcount + + def delete(self, table, where, using=None, vars=None, _test=False): + """ + Deletes from `table` with clauses `where` and `using`. + + >>> db = DB(None, {}) + >>> name = 'Joe' + >>> db.delete('foo', where='name = $name', vars=locals(), _test=True) + + """ + if vars is None: + vars = {} + + where = self._where(where, vars) + + q = "DELETE FROM " + table + if using: + q += " USING " + sqllist(using) + + if where: + q += " WHERE " + where + + if _test: + return q + + db_cursor = self._db_cursor() + self._db_execute(db_cursor, q) + if not self.ctx.transactions: + self.ctx.commit() + return db_cursor.rowcount + + def _process_insert_query(self, query, tablename, seqname): + return query + + def transaction(self): + """Start a transaction.""" + return Transaction(self.ctx) + + +class PostgresDB(DB): + """Postgres driver.""" + + def __init__(self, **keywords): + if "pw" in keywords: + keywords["password"] = keywords.pop("pw") + + db_module = import_driver(pg_drivers, preferred=keywords.pop("driver", None)) + if db_module.__name__ == "psycopg2": + import psycopg2.extensions + + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) + + # if db is not provided `postgres` driver will take it from PGDATABASE + # environment variable. + if "db" in keywords: + keywords["database"] = keywords.pop("db") + + self.dbname = "postgres" + self.paramstyle = db_module.paramstyle + DB.__init__(self, db_module, keywords) + self.supports_multiple_insert = True + self._sequences = None + + def _process_insert_query(self, query, tablename, seqname): + if seqname is None: + # when seqname is not provided guess the seqname and make sure it exists + seqname = tablename + "_id_seq" + if seqname not in self._get_all_sequences(): + seqname = None + + if seqname: + query += "; SELECT currval('%s')" % seqname + + return query + + def _get_all_sequences(self): + """Query postgres to find names of all sequences used in this database.""" + if self._sequences is None: + q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'" + self._sequences = {c.relname for c in self.query(q)} + return self._sequences + + def _connect(self, keywords): + conn = DB._connect(self, keywords) + conn.set_client_encoding("UTF8") + return conn + + def _connect_with_pooling(self, keywords): + conn = DB._connect_with_pooling(self, keywords) + conn._con._con.set_client_encoding("UTF8") + return conn + + +class MySQLDB(DB): + def __init__(self, **keywords): + + db = import_driver(mysql_drivers, preferred=keywords.pop("driver", None)) + + if db.__name__ == "pymysql": + if "pw" in keywords: + keywords["password"] = keywords["pw"] + del keywords["pw"] + + elif db.__name__ == "MySQLdb": + if "pw" in keywords: + keywords["passwd"] = keywords.pop("pw") + + elif db.__name__ == "mysql.connector": + # Enabled buffered so that len can work as expected. + keywords.setdefault("buffered", True) + + if "pw" in keywords: + keywords["password"] = keywords["pw"] + del keywords["pw"] + + if "charset" not in keywords: + keywords["charset"] = "utf8" + elif keywords["charset"] is None: + del keywords["charset"] + + self.paramstyle = db.paramstyle = "pyformat" # it's both + self.dbname = "mysql" + DB.__init__(self, db, keywords) + self.supports_multiple_insert = True + + def _process_insert_query(self, query, tablename, seqname): + return query, SQLQuery("SELECT last_insert_id();") + + def _get_insert_default_values_query(self, table): + return "INSERT INTO %s () VALUES()" % table + + +def import_driver(drivers, preferred=None): + """Import the first available driver or preferred driver.""" + if preferred: + drivers = (preferred,) + + for d in drivers: + try: + return __import__(d, None, None, ["x"]) + except ImportError: + pass + raise ImportError("Unable to import " + " or ".join(drivers)) + + +class SqliteDB(DB): + def __init__(self, **keywords): + db = import_driver(sqlite_drivers, preferred=keywords.pop("driver", None)) + + if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]: + db.paramstyle = "qmark" + + # sqlite driver doesn't create datatime objects for timestamp columns + # unless `detect_types` option is passed. + # It seems to be supported in `sqlite3` and `pysqlite2` drivers, not + # surte about `sqlite`. + keywords.setdefault("detect_types", db.PARSE_DECLTYPES) + + self.dbname = "sqlite" + self.paramstyle = db.paramstyle + keywords["database"] = keywords.pop("db") + + # sqlite don't allows connections to be shared by threads + keywords["pooling"] = False + + DB.__init__(self, db, keywords) + + def _process_insert_query(self, query, tablename, seqname): + return query, SQLQuery("SELECT last_insert_rowid();") + + def create_result_set(self, cursor): + return SqliteResultSet(cursor) + + +class FirebirdDB(DB): + """Firebird Database.""" + + def __init__(self, **keywords): + try: + import kinterbasdb as db + except Exception: + db = None + pass + if "pw" in keywords: + keywords["password"] = keywords.pop("pw") + keywords["database"] = keywords.pop("db") + + self.paramstyle = db.paramstyle + + DB.__init__(self, db, keywords) + + def delete(self, table, where=None, using=None, vars=None, _test=False): + # firebird doesn't support using clause + using = None + return DB.delete(self, table, where, using, vars, _test) + + def sql_clauses(self, what, tables, where, group, order, limit, offset): + return ( + ("SELECT", ""), + ("FIRST", limit), + ("SKIP", offset), + ("", what), + ("FROM", sqllist(tables)), + ("WHERE", where), + ("GROUP BY", group), + ("ORDER BY", order), + ) + + +class MSSQLDB(DB): + def __init__(self, **keywords): + import pymssql as db + + if "pw" in keywords: + keywords["password"] = keywords.pop("pw") + keywords["database"] = keywords.pop("db") + self.dbname = "mssql" + DB.__init__(self, db, keywords) + + def _process_query(self, sql_query): + """Takes the SQLQuery object and returns query string and parameters.""" + # MSSQLDB expects params to be a tuple. + # Overwriting the default implementation to convert params to tuple. + paramstyle = getattr(self, "paramstyle", "pyformat") + query = sql_query.query(paramstyle) + params = sql_query.values() + return query, tuple(params) + + def sql_clauses(self, what, tables, where, group, order, limit, offset): + return ( + ("SELECT", what), + ("TOP", limit), + ("FROM", sqllist(tables)), + ("WHERE", where), + ("GROUP BY", group), + ("ORDER BY", order), + ("OFFSET", offset), + ) + + def _test(self): + """Test LIMIT. + + Fake presence of pymssql module for running tests. + >>> import sys + >>> sys.modules['pymssql'] = sys.modules['sys'] + + MSSQL has TOP clause instead of LIMIT clause. + >>> db = MSSQLDB(db='test', user='joe', pw='secret') + >>> db.select('foo', limit=4, _test=True) + + """ + pass + + +class OracleDB(DB): + def __init__(self, **keywords): + import cx_Oracle as db + + if "pw" in keywords: + keywords["password"] = keywords.pop("pw") + + # @@ TODO: use db.makedsn if host, port is specified + keywords["dsn"] = keywords.pop("db") + self.dbname = "oracle" + db.paramstyle = "numeric" + self.paramstyle = db.paramstyle + + # oracle doesn't support pooling + keywords.pop("pooling", None) + DB.__init__(self, db, keywords) + + def _process_insert_query(self, query, tablename, seqname): + if seqname is None: + # It is not possible to get seq name from table name in Oracle + return query + else: + return query + "; SELECT %s.currval FROM dual" % seqname + + +def dburl2dict(url): + """ + Takes a URL to a database and parses it into an equivalent dictionary. + + >>> dburl2dict('postgres:///mygreatdb') == {'pw': None, 'dbn': 'postgres', 'db': 'mygreatdb', 'host': None, 'user': None, 'port': None} + True + >>> dburl2dict('postgres://james:day@serverfarm.example.net:5432/mygreatdb') == {'pw': 'day', 'dbn': 'postgres', 'db': 'mygreatdb', 'host': 'serverfarm.example.net', 'user': 'james', 'port': 5432} + True + >>> dburl2dict('postgres://james:day@serverfarm.example.net/mygreatdb') == {'pw': 'day', 'dbn': 'postgres', 'db': 'mygreatdb', 'host': 'serverfarm.example.net', 'user': 'james', 'port': None} + True + >>> dburl2dict('postgres://james:d%40y@serverfarm.example.net/mygreatdb') == {'pw': 'd@y', 'dbn': 'postgres', 'db': 'mygreatdb', 'host': 'serverfarm.example.net', 'user': 'james', 'port': None} + True + >>> dburl2dict('mysql://james:d%40y@serverfarm.example.net/mygreatdb') == {'pw': 'd@y', 'dbn': 'mysql', 'db': 'mygreatdb', 'host': 'serverfarm.example.net', 'user': 'james', 'port': None} + True + >>> dburl2dict('sqlite:///mygreatdb.db') + {'db': 'mygreatdb.db', 'dbn': 'sqlite'} + >>> dburl2dict('sqlite:////absolute/path/mygreatdb.db') + {'db': '/absolute/path/mygreatdb.db', 'dbn': 'sqlite'} + """ + parts = urlparse.urlparse(unquote(url)) + + if parts.scheme == "sqlite": + return {"dbn": parts.scheme, "db": parts.path[1:]} + else: + return { + "dbn": parts.scheme, + "user": parts.username, + "pw": parts.password, + "db": parts.path[1:], + "host": parts.hostname, + "port": parts.port, + } + + +_databases = {} + + +def database(dburl=None, **params): + """Creates appropriate database using params. + + Pooling will be enabled if DBUtils module is available. + Pooling can be disabled by passing pooling=False in params. + """ + if not dburl and not params: + dburl = os.environ["DATABASE_URL"] + + if dburl: + params = dburl2dict(dburl) + + dbn = params.pop("dbn") + if dbn in _databases: + return _databases[dbn](**params) + else: + raise UnknownDB(dbn) + + +def register_database(name, clazz): + """ + Register a database. + + >>> class LegacyDB(DB): + ... def __init__(self, **params): + ... pass + ... + >>> register_database('legacy', LegacyDB) + >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret') + """ + _databases[name] = clazz + + +register_database("mysql", MySQLDB) +register_database("postgres", PostgresDB) +register_database("sqlite", SqliteDB) +register_database("firebird", FirebirdDB) +register_database("mssql", MSSQLDB) +register_database("oracle", OracleDB) + + +def _interpolate(format): + """ + Takes a format string and returns a list of 2-tuples of the form + (boolean, string) where boolean says whether string should be evaled + or not. + + from (public domain, Ka-Ping Yee) + """ + + def matchorfail(text, pos): + match = tokenprog.match(text, pos) + if match is None: + raise _ItplError(text, pos) + return match, match.end() + + namechars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + chunks = [] + pos = 0 + + while 1: + dollar = format.find("$", pos) + if dollar < 0: + break + nextchar = format[dollar + 1] + + if nextchar == "{": + chunks.append((0, format[pos:dollar])) + pos, level = dollar + 2, 1 + while level: + match, pos = matchorfail(format, pos) + tstart, tend = match.regs[3] + token = format[tstart:tend] + if token == "{": + level = level + 1 + elif token == "}": + level = level - 1 + chunks.append((1, format[dollar + 2 : pos - 1])) + + elif nextchar in namechars: + chunks.append((0, format[pos:dollar])) + match, pos = matchorfail(format, dollar + 1) + while pos < len(format): + if ( + format[pos] == "." + and pos + 1 < len(format) + and format[pos + 1] in namechars + ): + match, pos = matchorfail(format, pos + 1) + elif format[pos] in "([": + pos, level = pos + 1, 1 + while level: + match, pos = matchorfail(format, pos) + tstart, tend = match.regs[3] + token = format[tstart:tend] + if token[0] in "([": + level = level + 1 + elif token[0] in ")]": + level = level - 1 + else: + break + chunks.append((1, format[dollar + 1 : pos])) + else: + chunks.append((0, format[pos : dollar + 1])) + pos = dollar + 1 + (nextchar == "$") + + if pos < len(format): + chunks.append((0, format[pos:])) + return chunks + + +class _Node: + def __init__(self, type, first, second=None): + self.type = type + self.first = first + self.second = second + + def __eq__(self, other): + return ( + isinstance(other, _Node) + and self.type == other.type + and self.first == other.first + and self.second == other.second + ) + + def __repr__(self): + return f"Node({self.type!r}, {self.first!r}, {self.second!r})" + + +class Parser: + """Parser to parse string templates like "Hello $name". + + Loosely based on (public domain, Ka-Ping Yee) + """ + + namechars = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_" + + def __init__(self): + self.reset() + + def reset(self): + self.pos = 0 + self.level = 0 + self.text = "" + + def parse(self, text): + """Parses the given text and returns a parse tree.""" + self.reset() + self.text = text + return self.parse_all() + + def parse_all(self): + while True: + dollar = self.text.find("$", self.pos) + if dollar < 0: + break + nextchar = self.text[dollar + 1] + if nextchar in self.namechars: + yield _Node("text", self.text[self.pos : dollar]) + self.pos = dollar + 1 + yield self.parse_expr() + + # for supporting ${x.id}, for backward compatibility + elif nextchar == "{": + saved_pos = self.pos + self.pos = dollar + 2 # skip "${" + expr = self.parse_expr() + if self.text[self.pos] == "}": + self.pos += 1 + yield _Node("text", self.text[self.pos : dollar]) + yield expr + else: + self.pos = saved_pos + break + else: + yield _Node("text", self.text[self.pos : dollar + 1]) + self.pos = dollar + 1 + # $$ is used to escape $ + if nextchar == "$": + self.pos += 1 + + if self.pos < len(self.text): + yield _Node("text", self.text[self.pos :]) + + def match(self): + match = tokenprog.match(self.text, self.pos) + if match is None: + raise _ItplError(self.text, self.pos) + return match, match.end() + + def is_literal(self, text): + return text and text[0] in "0123456789\"'" + + def parse_expr(self): + match, pos = self.match() + if self.is_literal(match.group()): + expr = _Node("literal", match.group()) + else: + expr = _Node("param", self.text[self.pos : pos]) + self.pos = pos + while self.pos < len(self.text): + if ( + self.text[self.pos] == "." + and self.pos + 1 < len(self.text) + and self.text[self.pos + 1] in self.namechars + ): + self.pos += 1 + match, pos = self.match() + attr = match.group() + expr = _Node("getattr", expr, attr) + self.pos = pos + elif self.text[self.pos] == "[": + saved_pos = self.pos + self.pos += 1 + key = self.parse_expr() + if self.text[self.pos] == "]": + self.pos += 1 + expr = _Node("getitem", expr, key) + else: + self.pos = saved_pos + break + else: + break + return expr + + +class SafeEval: + """Safe evaluator for binding params to db queries.""" + + def safeeval(self, text, mapping): + nodes = Parser().parse(text) + return SQLQuery.join([self.eval_node(node, mapping) for node in nodes], "") + + def eval_node(self, node, mapping): + if node.type == "text": + return node.first + else: + return sqlquote(self.eval_expr(node, mapping)) + + def eval_expr(self, node, mapping): + if node.type == "literal": + return ast.literal_eval(node.first) + elif node.type == "getattr": + return getattr(self.eval_expr(node.first, mapping), node.second) + elif node.type == "getitem": + return self.eval_expr(node.first, mapping)[ + self.eval_expr(node.second, mapping) + ] + elif node.type == "param": + return mapping[node.first] + + +def test_parser(): + def f(text, expected): + p = Parser() + nodes = list(p.parse(text)) + print(repr(text), nodes) + assert nodes == expected, "Expected %r" % expected + + f("Hello", [_Node("text", "Hello")]) + f("Hello $name", [_Node("text", "Hello "), _Node("param", "name")]) + f( + "Hello $name.foo", + [_Node("text", "Hello "), _Node("getattr", _Node("param", "name"), "foo")], + ) + f( + "WHERE id=$self.id LIMIT 1", + [ + _Node("text", "WHERE id="), + _Node("getattr", _Node("param", "self", None), "id"), + _Node("text", " LIMIT 1"), + ], + ) + + f( + "WHERE id=$self['id'] LIMIT 1", + [ + _Node("text", "WHERE id="), + _Node("getitem", _Node("param", "self", None), _Node("literal", "'id'")), + _Node("text", " LIMIT 1"), + ], + ) + + +def test_safeeval(): + def f(q, vars): + return SafeEval().safeeval(q, vars) + + print(f("WHERE id=$id", {"id": 1}).items) + assert f("WHERE id=$id", {"id": 1}).items == ["WHERE id=", sqlparam(1)] + + +if __name__ == "__main__": + import doctest + + doctest.testmod() + test_parser() + test_safeeval() diff --git a/web/debugerror.py b/web/debugerror.py new file mode 100644 index 0000000..7bf992e --- /dev/null +++ b/web/debugerror.py @@ -0,0 +1,377 @@ +""" +pretty debug errors +(part of web.py) + +portions adapted from Django +Copyright (c) 2005, the Lawrence Journal-World +Used under the modified BSD license: +http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5 +""" + +__all__ = ["debugerror", "djangoerror", "emailerrors"] + +import os +import os.path +import pprint +import sys +import traceback + +from . import webapi as web +from .net import websafe +from .template import Template +from .utils import safestr, sendmail + + +def update_globals_template(t, globals): + t.t.__globals__.update(globals) + + +whereami = os.path.join(os.getcwd(), __file__) +whereami = os.path.sep.join(whereami.split(os.path.sep)[:-1]) +djangoerror_t = """\ +$def with (exception_type, exception_value, frames) + + + + + + $exception_type at $ctx.path + + + + + +$def dicttable (d, kls='req', id=None): + $ items = d and list(d.items()) or [] + $items.sort() + $:dicttable_items(items, kls, id) + +$def dicttable_items(items, kls='req', id=None): + $if items: + + + $for k, v in items: + + +
VariableValue
$k
$prettify(v)
+ $else: +

No data.

+ +
+

$exception_type at $ctx.path

+

$exception_value

+ + + + + + +
Python$frames[0].filename in $frames[0].function, line $frames[0].lineno
Web$ctx.method $ctx.home$ctx.path
+
+
+

Traceback (innermost first)

+
    +$for frame in frames: +
  • + $frame.filename in $frame.function + $if frame.context_line is not None: +
    + $if frame.pre_context: +
      + $for line in frame.pre_context: +
    1. $line
    2. +
    +
    1. $frame.context_line ...
    + $if frame.post_context: +
      + $for line in frame.post_context: +
    1. $line
    2. +
    +
    + + $if frame.vars: +
    + Local vars + $# $inspect.formatargvalues(*inspect.getargvalues(frame['tb'].tb_frame)) +
    + $:dicttable(frame.vars, kls='vars', id=('v' + str(frame.id))) +
  • +
+
+ +
+$if ctx.output or ctx.headers: +

Response so far

+

HEADERS

+ $:dicttable_items(ctx.headers) + +

BODY

+

+ $ctx.output +

+ +

Request information

+ +

INPUT

+$:dicttable(web.input(_unicode=False)) + + +$:dicttable(web.cookies()) + +

META

+$ newctx = [(k, v) for (k, v) in ctx.iteritems() if not k.startswith('_') and not isinstance(v, dict)] +$:dicttable(dict(newctx)) + +

ENVIRONMENT

+$:dicttable(ctx.env) +
+ +
+

+ You're seeing this error because you have web.config.debug + set to True. Set that to False if you don't want to see this. +

+
+ + + +""" # noqa: W605 + +djangoerror_r = None + + +def djangoerror(): + def _get_lines_from_file(filename, lineno, context_lines): + """ + Returns context_lines before and after lineno from file. + Returns (pre_context_lineno, pre_context, context_line, post_context). + """ + try: + source = open(filename).readlines() + lower_bound = max(0, lineno - context_lines) + upper_bound = lineno + context_lines + + pre_context = [line.strip("\n") for line in source[lower_bound:lineno]] + context_line = source[lineno].strip("\n") + post_context = [ + line.strip("\n") for line in source[lineno + 1 : upper_bound] + ] + + return lower_bound, pre_context, context_line, post_context + except (OSError, IndexError): + return None, [], None, [] + + exception_type, exception_value, tback = sys.exc_info() + frames = [] + while tback is not None: + filename = tback.tb_frame.f_code.co_filename + function = tback.tb_frame.f_code.co_name + lineno = tback.tb_lineno - 1 + + # hack to get correct line number for templates + lineno += tback.tb_frame.f_locals.get("__lineoffset__", 0) + + ( + pre_context_lineno, + pre_context, + context_line, + post_context, + ) = _get_lines_from_file(filename, lineno, 7) + + if "__hidetraceback__" not in tback.tb_frame.f_locals: + frames.append( + web.storage( + { + "tback": tback, + "filename": filename, + "function": function, + "lineno": lineno, + "vars": tback.tb_frame.f_locals, + "id": id(tback), + "pre_context": pre_context, + "context_line": context_line, + "post_context": post_context, + "pre_context_lineno": pre_context_lineno, + } + ) + ) + tback = tback.tb_next + frames.reverse() + + def prettify(x): + try: + out = pprint.pformat(x) + except Exception as e: + out = "[could not display: <" + e.__class__.__name__ + ": " + str(e) + ">]" + return out + + global djangoerror_r + if djangoerror_r is None: + djangoerror_r = Template(djangoerror_t, filename=__file__, filter=websafe) + + t = djangoerror_r + globals = { + "ctx": web.ctx, + "web": web, + "dict": dict, + "str": str, + "prettify": prettify, + } + update_globals_template(t, globals) + return t(exception_type, exception_value, frames) + + +def debugerror(): + """ + A replacement for `internalerror` that presents a nice page with lots + of debug information for the programmer. + + (Based on the beautiful 500 page from [Django](http://djangoproject.com/), + designed by [Wilson Miner](http://wilsonminer.com/).) + """ + return web._InternalError(djangoerror()) + + +def emailerrors(to_address, olderror, from_address=None): + """ + Wraps the old `internalerror` handler (pass as `olderror`) to + additionally email all errors to `to_address`, to aid in + debugging production websites. + + Emails contain a normal text traceback as well as an + attachment containing the nice `debugerror` page. + """ + from_address = from_address or to_address + + def emailerrors_internal(): + error = olderror() + tb = sys.exc_info() + error_name = tb[0] + error_value = tb[1] + tb_txt = "".join(traceback.format_exception(*tb)) + path = web.ctx.path + request = web.ctx.method + " " + web.ctx.home + web.ctx.fullpath + + message = f"\n{request}\n\n{tb_txt}\n\n" + + sendmail( + "your buggy site <%s>" % from_address, + "the bugfixer <%s>" % to_address, + "bug: %(error_name)s: %(error_value)s (%(path)s)" % locals(), + message, + attachments=[dict(filename="bug.html", content=safestr(djangoerror()))], + ) + return error + + return emailerrors_internal + + +if __name__ == "__main__": + urls = ("/", "index") + from .application import application + + app = application(urls, globals()) + app.internalerror = debugerror + + class index: + def GET(self): + thisdoesnotexist # noqa: F821 + + app.run() diff --git a/web/form.py b/web/form.py new file mode 100644 index 0000000..3339c77 --- /dev/null +++ b/web/form.py @@ -0,0 +1,690 @@ +""" +HTML forms +(part of web.py) +""" + +import copy +import re + +from . import net, utils +from . import webapi as web + + +def attrget(obj, attr, value=None): + try: + if hasattr(obj, "has_key") and attr in obj: + return obj[attr] + except TypeError: + # Handle the case where has_key takes different number of arguments. + # This is the case with Model objects on appengine. See #134 + pass + if ( + hasattr(obj, "keys") and attr in obj + ): # needed for Py3, has_key doesn't exist anymore + return obj[attr] + elif hasattr(obj, attr): + return getattr(obj, attr) + return value + + +class Form: + r""" + HTML form. + + >>> f = Form(Textbox("x")) + >>> f.render() + u'\n \n
' + >>> f.fill(x="42") + True + >>> f.render() + u'\n \n
' + """ + + def __init__(self, *inputs, **kw): + self.inputs = inputs + self.valid = True + self.note = None + self.validators = kw.pop("validators", []) + + def __call__(self, x=None): + o = copy.deepcopy(self) + if x: + o.validates(x) + return o + + def render(self): + out = "" + out += self.rendernote(self.note) + out += "\n" + + for i in self.inputs: + html = ( + utils.safeunicode(i.pre) + + i.render() + + self.rendernote(i.note) + + utils.safeunicode(i.post) + ) + if i.is_hidden(): + out += ' \n' % ( + html + ) + else: + out += ( + ' \n' + % (net.websafe(i.id), net.websafe(i.description), html) + ) + out += "
%s
%s
" + return out + + def render_css(self): + out = [] + out.append(self.rendernote(self.note)) + for i in self.inputs: + if not i.is_hidden(): + out.append( + '' + % (net.websafe(i.id), net.websafe(i.description)) + ) + out.append(i.pre) + out.append(i.render()) + out.append(self.rendernote(i.note)) + out.append(i.post) + out.append("\n") + return "".join(out) + + def rendernote(self, note): + if note: + return '%s' % net.websafe(note) + else: + return "" + + def validates(self, source=None, _validate=True, **kw): + source = source or kw or web.input() + out = True + for i in self.inputs: + v = attrget(source, i.name) + if _validate: + out = i.validate(v) and out + else: + i.set_value(v) + if _validate: + out = out and self._validate(source) + self.valid = out + return out + + def _validate(self, value): + self.value = value + for v in self.validators: + if not v.valid(value): + self.note = v.msg + return False + return True + + def fill(self, source=None, **kw): + return self.validates(source, _validate=False, **kw) + + def __getitem__(self, i): + for x in self.inputs: + if x.name == i: + return x + raise KeyError(i) + + def __getattr__(self, name): + # don't interfere with deepcopy + inputs = self.__dict__.get("inputs") or [] + for x in inputs: + if x.name == name: + return x + raise AttributeError(name) + + def get(self, i, default=None): + try: + return self[i] + except KeyError: + return default + + def _get_d(self): # @@ should really be form.attr, no? + return utils.storage([(i.name, i.get_value()) for i in self.inputs]) + + d = property(_get_d) + + +class Input: + """Generic input. Type attribute must be specified when called directly. + + See also: + + Currently only types which can be written inside one `` tag are + supported. + + - For checkbox, please use `Checkbox` class for better control. + - For radiobox, please use `Radio` class for better control. + + >>> Input(name='foo', type='email', value="user@domain.com").render() + u'' + >>> Input(name='foo', type='number', value="bar").render() + u'' + >>> Input(name='num', type="number", min='0', max='10', step='2', value='5').render() + u'' + >>> Input(name='foo', type="tel", value='55512345').render() + u'' + >>> Input(name='search', type="search", value='Search').render() + u'' + >>> Input(name='search', type="search", value='Search', required='required', pattern='[a-z0-9]{2,30}', placeholder='Search...').render() + u'' + >>> Input(name='url', type="url", value='url').render() + u'' + >>> Input(name='range', type="range", min='0', max='10', step='2', value='5').render() + u'' + >>> Input(name='color', type="color").render() + u'' + >>> Input(name='f', type="file", accept=".doc,.docx,.xml").render() + u'' + """ + + def __init__(self, name, *validators, **attrs): + self.name = name + self.validators = validators + self.attrs = attrs = AttributeList(attrs) + + self.type = attrs.pop("type", None) + self.description = attrs.pop("description", name) + self.value = attrs.pop("value", None) + self.pre = attrs.pop("pre", "") + self.post = attrs.pop("post", "") + self.note = None + + self.id = attrs.setdefault("id", self.get_default_id()) + + if "class_" in attrs: + attrs["class"] = attrs["class_"] + del attrs["class_"] + + def is_hidden(self): + return False + + def get_type(self): + if self.type is not None: + return self.type + else: + raise AttributeError("missing attribute 'type'") + + def get_default_id(self): + return self.name + + def validate(self, value): + self.set_value(value) + + for v in self.validators: + if not v.valid(value): + self.note = v.msg + return False + return True + + def set_value(self, value): + self.value = value + + def get_value(self): + return self.value + + def render(self): + attrs = self.attrs.copy() + attrs["type"] = self.get_type() + if self.value is not None: + attrs["value"] = self.value + attrs["name"] = self.name + attrs["id"] = self.id + return "" % attrs + + def rendernote(self, note): + if note: + return '%s' % net.websafe(note) + else: + return "" + + def addatts(self): + # add leading space for backward-compatibility + return " " + str(self.attrs) + + +class AttributeList(dict): + """List of attributes of input. + + >>> a = AttributeList(type='text', name='x', value=20) + >>> a + + """ + + def copy(self): + return AttributeList(self) + + def __str__(self): + return " ".join([f'{k}="{net.websafe(v)}"' for k, v in sorted(self.items())]) + + def __repr__(self): + return "" % repr(str(self)) + + +class Textbox(Input): + """Textbox input. + + >>> Textbox(name='foo', value='bar').render() + u'' + >>> Textbox(name='foo', value=0).render() + u'' + """ + + def get_type(self): + return "text" + + +class Password(Input): + """Password input. + + >>> Password(name='password', value='secret').render() + u'' + """ + + def get_type(self): + return "password" + + +class Textarea(Input): + """Textarea input. + + >>> Textarea(name='foo', value='bar').render() + u'' + """ + + def render(self): + attrs = self.attrs.copy() + attrs["name"] = self.name + value = net.websafe(self.value or "") + return f"" + + +class Dropdown(Input): + r"""Dropdown/select input. + + >>> Dropdown(name='foo', args=['a', 'b', 'c'], value='b').render() + u'\n' + >>> Dropdown(name='foo', args=[('a', 'aa'), ('b', 'bb'), ('c', 'cc')], value='b').render() + u'\n' + """ + + def __init__(self, name, args, *validators, **attrs): + self.args = args + super().__init__(name, *validators, **attrs) + + def render(self): + attrs = self.attrs.copy() + attrs["name"] = self.name + + x = "\n" + return x + + def _render_option(self, arg, indent=" "): + if isinstance(arg, (tuple, list)): + value, desc = arg + else: + value, desc = arg, arg + + value = utils.safestr(value) + if isinstance(self.value, (tuple, list)): + s_value = [utils.safestr(x) for x in self.value] + else: + s_value = utils.safestr(self.value) + + if s_value == value or (isinstance(s_value, list) and value in s_value): + select_p = ' selected="selected"' + else: + select_p = "" + return indent + '{}\n'.format( + select_p, + net.websafe(value), + net.websafe(desc), + ) + + +class GroupedDropdown(Dropdown): + r"""Grouped Dropdown/select input. + + >>> GroupedDropdown(name='car_type', args=(('Swedish Cars', ('Volvo', 'Saab')), ('German Cars', ('Mercedes', 'Audi'))), value='Audi').render() + u'\n' + >>> GroupedDropdown(name='car_type', args=(('Swedish Cars', (('v', 'Volvo'), ('s', 'Saab'))), ('German Cars', (('m', 'Mercedes'), ('a', 'Audi')))), value='a').render() + u'\n' + + """ + + def __init__(self, name, args, *validators, **attrs): + self.args = args + super().__init__(name, *validators, **attrs) + + def render(self): + attrs = self.attrs.copy() + attrs["name"] = self.name + + x = "\n" + return x + + +class Radio(Input): + def __init__(self, name, args, *validators, **attrs): + self.args = args + super().__init__(name, *validators, **attrs) + + def render(self): + x = "" + for idx, arg in enumerate(self.args, start=1): + if isinstance(arg, (tuple, list)): + value, desc = arg + else: + value, desc = arg, arg + attrs = self.attrs.copy() + attrs["name"] = self.name + attrs["type"] = "radio" + attrs["value"] = value + attrs["id"] = self.name + str(idx) + if self.value == value: + attrs["checked"] = "checked" + x += f" {net.websafe(desc)}" + x += "" + return x + + +class Checkbox(Input): + """Checkbox input. + + >>> Checkbox('foo', value='bar', checked=True).render() + u'' + >>> Checkbox('foo', value='bar').render() + u'' + >>> c = Checkbox('foo', value='bar') + >>> c.validate('on') + True + >>> c.render() + u'' + """ + + def __init__(self, name, *validators, **attrs): + self.checked = attrs.pop("checked", False) + Input.__init__(self, name, *validators, **attrs) + + def get_default_id(self): + value = utils.safestr(self.value or "") + return self.name + "_" + value.replace(" ", "_") + + def render(self): + attrs = self.attrs.copy() + attrs["type"] = "checkbox" + attrs["name"] = self.name + attrs["value"] = self.value + + if self.checked: + attrs["checked"] = "checked" + return "" % attrs + + def set_value(self, value): + self.checked = bool(value) + + def get_value(self): + return self.checked + + +class Button(Input): + """HTML Button. + + >>> Button("save").render() + u'' + >>> Button("action", value="save", html="Save Changes").render() + u'' + """ + + def __init__(self, name, *validators, **attrs): + super().__init__(name, *validators, **attrs) + self.description = "" + + def render(self): + attrs = self.attrs.copy() + attrs["name"] = self.name + if self.value is not None: + attrs["value"] = self.value + html = attrs.pop("html", None) or net.websafe(self.name) + return f"" + + +class Hidden(Input): + """Hidden Input. + + >>> Hidden(name='foo', value='bar').render() + u'' + """ + + def is_hidden(self): + return True + + def get_type(self): + return "hidden" + + +class File(Input): + """File input. + + >>> File(name='f', accept=".doc,.docx,.xml").render() + u'' + """ + + def get_type(self): + return "file" + + +class Telephone(Input): + """Telephone input. + + See: + + >>> Telephone(name='tel', value='55512345').render() + u'' + """ + + def get_type(self): + return "tel" + + +class Email(Input): + """Email input. + + See: + + >>> Email(name='email', value='me@example.org').render() + u'' + + """ + + def get_type(self): + return "email" + + +class Date(Input): + """Date input. + + Note: Not supported by desktop Safari, Internet Explorer, or Opera Mini + + See: + + >>> Date(name='date', value='2020-04-01').render() + u'' + + """ + + def get_type(self): + return "date" + + +class Time(Input): + """Time input. + + Note: Not supported by desktop Safari, Internet Explorer, or Opera Mini + + See: + + >>> Time(name='time', value='07:00').render() + u'' + + """ + + def get_type(self): + return "time" + + +class Search(Input): + """Search input. + + See: + + >> Search(name='search', value='Search').render() + u'' + >>> Search(name='search', value='Search', required='required', pattern='[a-z0-9]{2,30}', placeholder='Search...').render() + u'' + + """ + + def get_type(self): + return "search" + + +class Url(Input): + """URL input. + + See: + + >>> Url(name='url', value='url').render() + u'' + """ + + def get_type(self): + return "url" + + +class Number(Input): + """Number input. + + See: + + >>> Number(name='num', min='0', max='10', step='2', value='5').render() + u'' + """ + + def get_type(self): + return "number" + + +class Range(Input): + """Range input. + + See: + + >>> Range(name='range', min='0', max='10', step='2', value='5').render() + u'' + """ + + def get_type(self): + return "range" + + +class Color(Input): + """Color input. + + Note: Not supported by Internet Explorer or Opera Mini + + See: + + >>> Color(name='color').render() + u'' + """ + + def get_type(self): + return "color" + + +class Datalist(Input): + """Datalist input. + + This is currently supported by Chrome, Firefox, Edge, and Opera. It is not + supported on Safari or Internet Explorer. Use it with caution. + + Datalist cannot be used separately. It must be bound to an input. + + + + >>> Datalist(name='list', args=[('a', 'b'), ('c', 'd')]).render() + u'' + >>> Datalist(name='list', args=[['a', 'b'], ['c', 'd']]).render() + u'' + >>> Datalist(name='list', args=['a', 'b', 'c', 'd']).render() + u'' + """ + + def __init__(self, name, args, *validators, **kwargs): + self.args = args + super().__init__(name, *validators, **kwargs) + + def render(self): + attrs = self.attrs.copy() + attrs["name"] = self.name + label_p = "" + x = "" % attrs + for arg in self.args: + if isinstance(arg, (tuple, list)): + label_p = ' label="%s"' % net.websafe(arg[0]) + label = net.websafe(arg[1]) + else: + label = net.websafe(arg) + x += f'' + x += "" + return x + + +class Validator: + def __deepcopy__(self, memo): + return copy.copy(self) + + def __init__(self, msg, test, jstest=None): + utils.autoassign(self, locals()) + + def valid(self, value): + try: + return self.test(value) + except: + return False + + +notnull = Validator("Required", bool) + + +class regexp(Validator): + def __init__(self, rexp, msg): + self.rexp = re.compile(rexp) + self.msg = msg + + def valid(self, value): + return bool(self.rexp.match(value)) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/web/http.py b/web/http.py new file mode 100644 index 0000000..8799001 --- /dev/null +++ b/web/http.py @@ -0,0 +1,168 @@ +""" +HTTP Utilities +(from web.py) +""" + +__all__ = [ + "expires", + "lastmodified", + "prefixurl", + "modified", + "changequery", + "url", + "profiler", +] + +import datetime + +from . import net, utils +from . import webapi as web +from .py3helpers import iteritems + +try: + from urllib.parse import urlencode as urllib_urlencode +except ImportError: + from urllib import urlencode as urllib_urlencode + + +def prefixurl(base=""): + """ + Sorry, this function is really difficult to explain. + Maybe some other time. + """ + url = web.ctx.path.lstrip("/") + for i in range(url.count("/")): + base += "../" + if not base: + base = "./" + return base + + +def expires(delta): + """ + Outputs an `Expires` header for `delta` from now. + `delta` is a `timedelta` object or a number of seconds. + """ + if isinstance(delta, int): + delta = datetime.timedelta(seconds=delta) + date_obj = datetime.datetime.utcnow() + delta + web.header("Expires", net.httpdate(date_obj)) + + +def lastmodified(date_obj): + """Outputs a `Last-Modified` header for `datetime`.""" + web.header("Last-Modified", net.httpdate(date_obj)) + + +def modified(date=None, etag=None): + """ + Checks to see if the page has been modified since the version in the + requester's cache. + + When you publish pages, you can include `Last-Modified` and `ETag` + with the date the page was last modified and an opaque token for + the particular version, respectively. When readers reload the page, + the browser sends along the modification date and etag value for + the version it has in its cache. If the page hasn't changed, + the server can just return `304 Not Modified` and not have to + send the whole page again. + + This function takes the last-modified date `date` and the ETag `etag` + and checks the headers to see if they match. If they do, it returns + `True`, or otherwise it raises NotModified error. It also sets + `Last-Modified` and `ETag` output headers. + """ + n = {x.strip('" ') for x in web.ctx.env.get("HTTP_IF_NONE_MATCH", "").split(",")} + m = net.parsehttpdate(web.ctx.env.get("HTTP_IF_MODIFIED_SINCE", "").split(";")[0]) + validate = False + if etag: + if "*" in n or etag in n: + validate = True + if date and m: + # we subtract a second because + # HTTP dates don't have sub-second precision + if date - datetime.timedelta(seconds=1) <= m: + validate = True + + if date: + lastmodified(date) + if etag: + web.header("ETag", '"' + etag + '"') + if validate: + raise web.notmodified() + else: + return True + + +def urlencode(query, doseq=0): + """ + Same as urllib.urlencode, but supports unicode strings. + + >>> urlencode({'text':'foo bar'}) + 'text=foo+bar' + >>> urlencode({'x': [1, 2]}, doseq=True) + 'x=1&x=2' + """ + + def convert(value, doseq=False): + if doseq and isinstance(value, list): + return [convert(v) for v in value] + else: + return utils.safestr(value) + + query = {k: convert(v, doseq) for k, v in query.items()} + return urllib_urlencode(query, doseq=doseq) + + +def changequery(query=None, **kw): + """ + Imagine you're at `/foo?a=1&b=2`. Then `changequery(a=3)` will return + `/foo?a=3&b=2` -- the same URL but with the arguments you requested + changed. + """ + if query is None: + query = web.rawinput(method="get") + for k, v in iteritems(kw): + if v is None: + query.pop(k, None) + else: + query[k] = v + out = web.ctx.path + if query: + out += "?" + urlencode(query, doseq=True) + return out + + +def url(path=None, doseq=False, **kw): + """ + Makes url by concatenating web.ctx.homepath and path and the + query string created using the arguments. + """ + if path is None: + path = web.ctx.path + if path.startswith("/"): + out = web.ctx.homepath + path + else: + out = path + + if kw: + out += "?" + urlencode(kw, doseq=doseq) + + return out + + +def profiler(app): + """Outputs basic profiling information at the bottom of each response.""" + from utils import profile + + def profile_internal(e, o): + out, result = profile(app)(e, o) + return list(out) + ["
" + net.websafe(result) + "
"] + + return profile_internal + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/web/httpserver.py b/web/httpserver.py new file mode 100644 index 0000000..43f6a24 --- /dev/null +++ b/web/httpserver.py @@ -0,0 +1,306 @@ +import os +import posixpath +import sys +from http.server import BaseHTTPRequestHandler, HTTPServer, SimpleHTTPRequestHandler +from io import BytesIO +from urllib import parse as urlparse +from urllib.parse import unquote + +from . import utils +from . import webapi as web + +__all__ = ["runsimple"] + + +def runbasic(func, server_address=("0.0.0.0", 8080)): + """ + Runs a simple HTTP server hosting WSGI app `func`. The directory `static/` + is hosted statically. + + Based on [WsgiServer][ws] from [Colin Stewart][cs]. + + [ws]: http://www.owlfish.com/software/wsgiutils/documentation/wsgi-server-api.html + [cs]: http://www.owlfish.com/ + """ + # Copyright (c) 2004 Colin Stewart (http://www.owlfish.com/) + # Modified somewhat for simplicity + # Used under the modified BSD license: + # http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5 + + import errno + import socket + import traceback + + import SocketServer + + class WSGIHandler(SimpleHTTPRequestHandler): + def run_wsgi_app(self): + protocol, host, path, parameters, query, fragment = urlparse.urlparse( + "http://dummyhost%s" % self.path + ) + + # we only use path, query + env = { + "wsgi.version": (1, 0), + "wsgi.url_scheme": "http", + "wsgi.input": self.rfile, + "wsgi.errors": sys.stderr, + "wsgi.multithread": 1, + "wsgi.multiprocess": 0, + "wsgi.run_once": 0, + "REQUEST_METHOD": self.command, + "REQUEST_URI": self.path, + "PATH_INFO": path, + "QUERY_STRING": query, + "CONTENT_TYPE": self.headers.get("Content-Type", ""), + "CONTENT_LENGTH": self.headers.get("Content-Length", ""), + "REMOTE_ADDR": self.client_address[0], + "SERVER_NAME": self.server.server_address[0], + "SERVER_PORT": str(self.server.server_address[1]), + "SERVER_PROTOCOL": self.request_version, + } + + for http_header, http_value in self.headers.items(): + env["HTTP_%s" % http_header.replace("-", "_").upper()] = http_value + + # Setup the state + self.wsgi_sent_headers = 0 + self.wsgi_headers = [] + + try: + # We have there environment, now invoke the application + result = self.server.app(env, self.wsgi_start_response) + try: + try: + for data in result: + if data: + self.wsgi_write_data(data) + finally: + if hasattr(result, "close"): + result.close() + except OSError as socket_err: + # Catch common network errors and suppress them + if socket_err.args[0] in (errno.ECONNABORTED, errno.EPIPE): + return + except socket.timeout: + return + except: + print(traceback.format_exc(), file=web.debug) + + if not self.wsgi_sent_headers: + # We must write out something! + self.wsgi_write_data(" ") + return + + do_POST = run_wsgi_app + do_PUT = run_wsgi_app + do_DELETE = run_wsgi_app + + def do_GET(self): + if self.path.startswith("/static/"): + SimpleHTTPRequestHandler.do_GET(self) + else: + self.run_wsgi_app() + + def wsgi_start_response(self, response_status, response_headers, exc_info=None): + if self.wsgi_sent_headers: + raise Exception("Headers already sent and start_response called again!") + # Should really take a copy to avoid changes in the application.... + self.wsgi_headers = (response_status, response_headers) + return self.wsgi_write_data + + def wsgi_write_data(self, data): + if not self.wsgi_sent_headers: + status, headers = self.wsgi_headers + # Need to send header prior to data + status_code = status[: status.find(" ")] + status_msg = status[status.find(" ") + 1 :] + self.send_response(int(status_code), status_msg) + for header, value in headers: + self.send_header(header, value) + self.end_headers() + self.wsgi_sent_headers = 1 + # Send the data + self.wfile.write(data) + + class WSGIServer(SocketServer.ThreadingMixIn, HTTPServer): + def __init__(self, func, server_address): + HTTPServer.HTTPServer.__init__(self, server_address, WSGIHandler) + self.app = func + self.serverShuttingDown = 0 + + print("http://%s:%d/" % server_address) + WSGIServer(func, server_address).serve_forever() + + +# The WSGIServer instance. +# Made global so that it can be stopped in embedded mode. +server = None + + +def runsimple(func, server_address=("0.0.0.0", 8080)): + """ + Runs [CherryPy][cp] WSGI server hosting WSGI app `func`. + The directory `static/` is hosted statically. + + [cp]: http://www.cherrypy.org + """ + global server + func = StaticMiddleware(func) + func = LogMiddleware(func) + + server = WSGIServer(server_address, func) + + if "/" in server_address[0]: + print("%s" % server_address) + else: + if server.ssl_adapter: + print("https://%s:%d/" % server_address) + else: + print("http://%s:%d/" % server_address) + + try: + server.start() + except (KeyboardInterrupt, SystemExit): + server.stop() + server = None + + +def WSGIServer(server_address, wsgi_app): + """Creates CherryPy WSGI server listening at `server_address` to serve `wsgi_app`. + This function can be overwritten to customize the webserver or use a different webserver. + """ + from cheroot import wsgi + + server = wsgi.Server(server_address, wsgi_app, server_name="localhost") + server.nodelay = not sys.platform.startswith( + "java" + ) # TCP_NODELAY isn't supported on the JVM + return server + + +class StaticApp(SimpleHTTPRequestHandler): + """WSGI application for serving static files.""" + + def __init__(self, environ, start_response): + self.headers = [] + self.environ = environ + self.start_response = start_response + self.directory = os.getcwd() + + def send_response(self, status, msg=""): + # the int(status) call is needed because in Py3 status is an enum.IntEnum and we need the integer behind + self.status = str(int(status)) + " " + msg + + def send_header(self, name, value): + self.headers.append((name, value)) + + def end_headers(self): + pass + + def log_message(*a): + pass + + def __iter__(self): + environ = self.environ + + self.path = environ.get("PATH_INFO", "") + self.client_address = ( + environ.get("REMOTE_ADDR", "-"), + environ.get("REMOTE_PORT", "-"), + ) + self.command = environ.get("REQUEST_METHOD", "-") + + self.wfile = BytesIO() # for capturing error + + try: + path = self.translate_path(self.path) + etag = '"%s"' % os.path.getmtime(path) + client_etag = environ.get("HTTP_IF_NONE_MATCH") + self.send_header("ETag", etag) + if etag == client_etag: + self.send_response(304, "Not Modified") + self.start_response(self.status, self.headers) + return + except OSError: + pass # Probably a 404 + + f = self.send_head() + self.start_response(self.status, self.headers) + + if f: + block_size = 16 * 1024 + while True: + buf = f.read(block_size) + if not buf: + break + yield buf + f.close() + else: + value = self.wfile.getvalue() + yield value + + +class StaticMiddleware: + """WSGI middleware for serving static files.""" + + def __init__(self, app, prefix="/static/"): + self.app = app + self.prefix = prefix + + def __call__(self, environ, start_response): + path = environ.get("PATH_INFO", "") + path = self.normpath(path) + + if path.startswith(self.prefix): + return StaticApp(environ, start_response) + else: + return self.app(environ, start_response) + + def normpath(self, path): + path2 = posixpath.normpath(unquote(path)) + if path.endswith("/"): + path2 += "/" + return path2 + + +class LogMiddleware: + """WSGI middleware for logging the status.""" + + def __init__(self, app): + self.app = app + self.format = '%s - - [%s] "%s %s %s" - %s' + + f = BytesIO() + + class FakeSocket: + def makefile(self, *a): + return f + + # take log_date_time_string method from BaseHTTPRequestHandler + self.log_date_time_string = BaseHTTPRequestHandler( + FakeSocket(), None, None + ).log_date_time_string + + def __call__(self, environ, start_response): + def xstart_response(status, response_headers, *args): + out = start_response(status, response_headers, *args) + self.log(status, environ) + return out + + return self.app(environ, xstart_response) + + def log(self, status, environ): + outfile = environ.get("wsgi.errors", web.debug) + req = environ.get("PATH_INFO", "_") + protocol = environ.get("ACTUAL_SERVER_PROTOCOL", "-") + method = environ.get("REQUEST_METHOD", "-") + host = "{}:{}".format( + environ.get("REMOTE_ADDR", "-"), + environ.get("REMOTE_PORT", "-"), + ) + + time = self.log_date_time_string() + + msg = self.format % (host, time, protocol, method, req, status) + print(utils.safestr(msg), file=outfile) diff --git a/web/net.py b/web/net.py new file mode 100644 index 0000000..e1c1609 --- /dev/null +++ b/web/net.py @@ -0,0 +1,279 @@ +""" +Network Utilities +(from web.py) +""" + + +import datetime +import re +import socket +import time + +try: + from urllib.parse import quote +except ImportError: + from urllib import quote + +__all__ = [ + "validipaddr", + "validip6addr", + "validipport", + "validip", + "validaddr", + "urlquote", + "httpdate", + "parsehttpdate", + "htmlquote", + "htmlunquote", + "websafe", +] + + +def validip6addr(address): + """ + Returns True if `address` is a valid IPv6 address. + + >>> validip6addr('::') + True + >>> validip6addr('aaaa:bbbb:cccc:dddd::1') + True + >>> validip6addr('1:2:3:4:5:6:7:8:9:10') + False + >>> validip6addr('12:10') + False + """ + try: + socket.inet_pton(socket.AF_INET6, address) + except (OSError, AttributeError, ValueError): + return False + + return True + + +def validipaddr(address): + """ + Returns True if `address` is a valid IPv4 address. + + >>> validipaddr('192.168.1.1') + True + >>> validipaddr('192.168. 1.1') + False + >>> validipaddr('192.168.1.800') + False + >>> validipaddr('192.168.1') + False + """ + try: + octets = address.split(".") + if len(octets) != 4: + return False + + for x in octets: + if " " in x: + return False + + if not (0 <= int(x) <= 255): + return False + except ValueError: + return False + return True + + +def validipport(port): + """ + Returns True if `port` is a valid IPv4 port. + + >>> validipport('9000') + True + >>> validipport('foo') + False + >>> validipport('1000000') + False + """ + try: + if not (0 <= int(port) <= 65535): + return False + except ValueError: + return False + return True + + +def validip(ip, defaultaddr="0.0.0.0", defaultport=8080): + """ + Returns `(ip_address, port)` from string `ip_addr_port` + + >>> validip('1.2.3.4') + ('1.2.3.4', 8080) + >>> validip('80') + ('0.0.0.0', 80) + >>> validip('192.168.0.1:85') + ('192.168.0.1', 85) + >>> validip('::') + ('::', 8080) + >>> validip('[::]:88') + ('::', 88) + >>> validip('[::1]:80') + ('::1', 80) + + """ + addr = defaultaddr + port = defaultport + + # Matt Boswell's code to check for ipv6 first + match = re.search(r"^\[([^]]+)\](?::(\d+))?$", ip) # check for [ipv6]:port + if match: + if validip6addr(match.group(1)): + if match.group(2): + if validipport(match.group(2)): + return (match.group(1), int(match.group(2))) + else: + return (match.group(1), port) + else: + if validip6addr(ip): + return (ip, port) + # end ipv6 code + + ip = ip.split(":", 1) + if len(ip) == 1: + if not ip[0]: + pass + elif validipaddr(ip[0]): + addr = ip[0] + elif validipport(ip[0]): + port = int(ip[0]) + else: + raise ValueError(":".join(ip) + " is not a valid IP address/port") + elif len(ip) == 2: + addr, port = ip + if not validipaddr(addr) or not validipport(port): + raise ValueError(":".join(ip) + " is not a valid IP address/port") + port = int(port) + else: + raise ValueError(":".join(ip) + " is not a valid IP address/port") + return (addr, port) + + +def validaddr(string_): + """ + Returns either (ip_address, port) or "/path/to/socket" from string_ + + >>> validaddr('/path/to/socket') + '/path/to/socket' + >>> validaddr('8000') + ('0.0.0.0', 8000) + >>> validaddr('127.0.0.1') + ('127.0.0.1', 8080) + >>> validaddr('127.0.0.1:8000') + ('127.0.0.1', 8000) + >>> validip('[::1]:80') + ('::1', 80) + >>> validaddr('fff') + Traceback (most recent call last): + ... + ValueError: fff is not a valid IP address/port + """ + if "/" in string_: + return string_ + else: + return validip(string_) + + +def urlquote(val): + """ + Quotes a string for use in a URL. + + >>> urlquote('://?f=1&j=1') + '%3A//%3Ff%3D1%26j%3D1' + >>> urlquote(None) + '' + >>> urlquote(u'\u203d') + '%E2%80%BD' + """ + if val is None: + return "" + + val = str(val).encode("utf-8") + return quote(val) + + +def httpdate(date_obj): + """ + Formats a datetime object for use in HTTP headers. + + >>> import datetime + >>> httpdate(datetime.datetime(1970, 1, 1, 1, 1, 1)) + 'Thu, 01 Jan 1970 01:01:01 GMT' + """ + return date_obj.strftime("%a, %d %b %Y %H:%M:%S GMT") + + +def parsehttpdate(string_): + """ + Parses an HTTP date into a datetime object. + + >>> parsehttpdate('Thu, 01 Jan 1970 01:01:01 GMT') + datetime.datetime(1970, 1, 1, 1, 1, 1) + """ + try: + t = time.strptime(string_, "%a, %d %b %Y %H:%M:%S %Z") + except ValueError: + return None + return datetime.datetime(*t[:6]) + + +def htmlquote(text): + r""" + Encodes `text` for raw use in HTML. + + >>> htmlquote(u"<'&\">") + u'<'&">' + """ + text = text.replace("&", "&") # Must be done first! + text = text.replace("<", "<") + text = text.replace(">", ">") + text = text.replace("'", "'") + text = text.replace('"', """) + return text + + +def htmlunquote(text): + r""" + Decodes `text` that's HTML quoted. + + >>> htmlunquote(u'<'&">') + u'<\'&">' + """ + text = text.replace(""", '"') + text = text.replace("'", "'") + text = text.replace(">", ">") + text = text.replace("<", "<") + text = text.replace("&", "&") # Must be done last! + return text + + +def websafe(val): + r""" + Converts `val` so that it is safe for use in Unicode HTML. + + >>> websafe("<'&\">") + u'<'&">' + >>> websafe(None) + u'' + >>> websafe(u'\u203d') == u'\u203d' + True + """ + if val is None: + return "" + + if isinstance(val, bytes): + val = val.decode("utf-8") + elif not isinstance(val, str): + val = str(val) + + return htmlquote(val) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/web/py3helpers.py b/web/py3helpers.py new file mode 100644 index 0000000..6466e2b --- /dev/null +++ b/web/py3helpers.py @@ -0,0 +1,7 @@ +"""Utilities for make the code run both on Python2 and Python3. +""" + +# Dictionary iteration +iterkeys = lambda d: iter(d.keys()) +itervalues = lambda d: iter(d.values()) +iteritems = lambda d: iter(d.items()) diff --git a/web/session.py b/web/session.py new file mode 100644 index 0000000..61ef7ea --- /dev/null +++ b/web/session.py @@ -0,0 +1,457 @@ +""" +Session Management +(from web.py) +""" + +import datetime +import os +import os.path +import shutil +import threading +import time +from copy import deepcopy +from hashlib import sha1 + +from . import utils +from . import webapi as web +from .py3helpers import iteritems + +try: + import cPickle as pickle +except ImportError: + import pickle + +from base64 import decodebytes, encodebytes + +__all__ = ["Session", "SessionExpired", "Store", "DiskStore", "DBStore", "MemoryStore"] + +web.config.session_parameters = utils.storage( + { + "cookie_name": "webpy_session_id", + "cookie_domain": None, + "cookie_path": None, + "samesite": None, + "timeout": 86400, # 24 * 60 * 60, # 24 hours in seconds + "ignore_expiry": True, + "ignore_change_ip": True, + "secret_key": "fLjUfxqXtfNoIldA0A0J", + "expired_message": "Session expired", + "httponly": True, + "secure": False, + } +) + + +class SessionExpired(web.HTTPError): + def __init__(self, message): + web.HTTPError.__init__(self, "200 OK", {}, data=message) + + +class Session: + """Session management for web.py""" + + __slots__ = [ + "store", + "_initializer", + "_last_cleanup_time", + "_config", + "_data", + "__getitem__", + "__setitem__", + "__delitem__", + ] + + def __init__(self, app, store, initializer=None): + self.store = store + self._initializer = initializer + self._last_cleanup_time = 0 + self._config = utils.storage(web.config.session_parameters) + self._data = utils.threadeddict() + + self.__getitem__ = self._data.__getitem__ + self.__setitem__ = self._data.__setitem__ + self.__delitem__ = self._data.__delitem__ + + if app: + app.add_processor(self._processor) + + def __contains__(self, name): + return name in self._data + + def __getattr__(self, name): + return getattr(self._data, name) + + def __setattr__(self, name, value): + if name in self.__slots__: + object.__setattr__(self, name, value) + else: + setattr(self._data, name, value) + + def __delattr__(self, name): + delattr(self._data, name) + + def _processor(self, handler): + """Application processor to setup session for every request""" + + self._cleanup() + self._load() + + try: + return handler() + finally: + self._save() + + def _load(self): + """Load the session from the store, by the id from cookie""" + cookie_name = self._config.cookie_name + self.session_id = web.cookies().get(cookie_name) + + # protection against session_id tampering + if self.session_id and not self._valid_session_id(self.session_id): + self.session_id = None + + self._check_expiry() + if self.session_id: + d = self.store[self.session_id] + self.update(d) + self._validate_ip() + + if not self.session_id: + self.session_id = self._generate_session_id() + + if self._initializer: + if isinstance(self._initializer, dict): + self.update(deepcopy(self._initializer)) + elif hasattr(self._initializer, "__call__"): + self._initializer() + + self.ip = web.ctx.ip + + def _check_expiry(self): + # check for expiry + if self.session_id and self.session_id not in self.store: + if self._config.ignore_expiry: + self.session_id = None + else: + return self.expired() + + def _validate_ip(self): + # check for change of IP + if self.session_id and self.get("ip", None) != web.ctx.ip: + if not self._config.ignore_change_ip: + return self.expired() + + def _save(self): + current_values = dict(self._data) + del current_values["session_id"] + del current_values["ip"] + + if not self.get("_killed"): + self._setcookie(self.session_id) + self.store[self.session_id] = dict(self._data) + else: + if web.cookies().get(self._config.cookie_name): + self._setcookie(self.session_id, expires=-1) + + def _setcookie(self, session_id, expires="", **kw): + cookie_name = self._config.cookie_name + cookie_domain = self._config.cookie_domain + cookie_path = self._config.cookie_path + httponly = self._config.httponly + secure = self._config.secure + samesite = kw.get("samesite", self._config.get("samesite", None)) + web.setcookie( + cookie_name, + session_id, + expires=expires, + domain=cookie_domain, + httponly=httponly, + secure=secure, + path=cookie_path, + samesite=samesite, + ) + + def _generate_session_id(self): + """Generate a random id for session""" + + while True: + rand = os.urandom(16) + now = time.time() + secret_key = self._config.secret_key + + hashable = f"{rand}{now}{utils.safestr(web.ctx.ip)}{secret_key}" + session_id = sha1(hashable.encode("utf-8")).hexdigest() + if session_id not in self.store: + break + return session_id + + def _valid_session_id(self, session_id): + rx = utils.re_compile("^[0-9a-fA-F]+$") + return rx.match(session_id) + + def _cleanup(self): + """Cleanup the stored sessions""" + current_time = time.time() + timeout = self._config.timeout + if current_time - self._last_cleanup_time > timeout: + self.store.cleanup(timeout) + self._last_cleanup_time = current_time + + def expired(self): + """Called when an expired session is atime""" + self._killed = True + self._save() + raise SessionExpired(self._config.expired_message) + + def kill(self): + """Kill the session, make it no longer available""" + del self.store[self.session_id] + self._killed = True + + +class Store: + """Base class for session stores""" + + def __contains__(self, key): + raise NotImplementedError() + + def __getitem__(self, key): + raise NotImplementedError() + + def __setitem__(self, key, value): + raise NotImplementedError() + + def cleanup(self, timeout): + """removes all the expired sessions""" + raise NotImplementedError() + + def encode(self, session_dict): + """encodes session dict as a string""" + pickled = pickle.dumps(session_dict) + return encodebytes(pickled) + + def decode(self, session_data): + """decodes the data to get back the session dict""" + if isinstance(session_data, str): + session_data = session_data.encode() + + pickled = decodebytes(session_data) + return pickle.loads(pickled) + + +class DiskStore(Store): + """ + Store for saving a session on disk. + + >>> import tempfile + >>> root = tempfile.mkdtemp() + >>> s = DiskStore(root) + >>> s['a'] = 'foo' + >>> s['a'] + 'foo' + >>> time.sleep(0.01) + >>> s.cleanup(0.01) + >>> s['a'] + Traceback (most recent call last): + ... + KeyError: 'a' + """ + + def __init__(self, root): + # if the storage root doesn't exists, create it. + if not os.path.exists(root): + os.makedirs(os.path.abspath(root)) + self.root = root + + def _get_path(self, key): + if os.path.sep in key: + raise ValueError("Bad key: %s" % repr(key)) + return os.path.join(self.root, key) + + def __contains__(self, key): + path = self._get_path(key) + return os.path.exists(path) + + def __getitem__(self, key): + path = self._get_path(key) + + if os.path.exists(path): + with open(path, "rb") as fh: + pickled = fh.read() + return self.decode(pickled) + else: + raise KeyError(key) + + def __setitem__(self, key, value): + path = self._get_path(key) + pickled = self.encode(value) + try: + tname = path + "." + threading.current_thread().getName() + f = open(tname, "wb") + try: + f.write(pickled) + finally: + f.close() + shutil.move(tname, path) # atomary operation + except OSError: + pass + + def __delitem__(self, key): + path = self._get_path(key) + if os.path.exists(path): + os.remove(path) + + def cleanup(self, timeout): + if not os.path.isdir(self.root): + return + + now = time.time() + for f in os.listdir(self.root): + path = self._get_path(f) + atime = os.stat(path).st_atime + if now - atime > timeout: + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + + +class DBStore(Store): + """Store for saving a session in database + Needs a table with the following columns: + + session_id CHAR(128) UNIQUE NOT NULL, + atime DATETIME NOT NULL default current_timestamp, + data TEXT + """ + + def __init__(self, db, table_name): + self.db = db + self.table = table_name + + def __contains__(self, key): + data = self.db.select(self.table, where="session_id=$key", vars=locals()) + return bool(list(data)) + + def __getitem__(self, key): + now = datetime.datetime.now() + try: + s = self.db.select(self.table, where="session_id=$key", vars=locals())[0] + self.db.update( + self.table, where="session_id=$key", atime=now, vars=locals() + ) + except IndexError: + raise KeyError(key) + else: + return self.decode(s.data) + + def __setitem__(self, key, value): + # Remove the leading `b` of bytes object (`b"..."`), otherwise encoded + # value is invalid base64 format. + pickled = self.encode(value).decode() + + now = datetime.datetime.now() + if key in self: + self.db.update( + self.table, + where="session_id=$key", + data=pickled, + atime=now, + vars=locals(), + ) + else: + self.db.insert(self.table, False, session_id=key, atime=now, data=pickled) + + def __delitem__(self, key): + self.db.delete(self.table, where="session_id=$key", vars=locals()) + + def cleanup(self, timeout): + timeout = datetime.timedelta( + timeout / (24.0 * 60 * 60) + ) # timedelta takes numdays as arg + last_allowed_time = datetime.datetime.now() - timeout + self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals()) + + +class ShelfStore: + """Store for saving session using `shelve` module. + + import shelve + store = ShelfStore(shelve.open('session.shelf')) + + XXX: is shelve thread-safe? + """ + + def __init__(self, shelf): + self.shelf = shelf + + def __contains__(self, key): + return key in self.shelf + + def __getitem__(self, key): + atime, v = self.shelf[key] + self[key] = v # update atime + return v + + def __setitem__(self, key, value): + self.shelf[key] = time.time(), value + + def __delitem__(self, key): + try: + del self.shelf[key] + except KeyError: + pass + + def cleanup(self, timeout): + now = time.time() + for k in self.shelf: + atime, v = self.shelf[k] + if now - atime > timeout: + del self[k] + + +class MemoryStore(Store): + """Store for saving a session in memory. + Useful where there is limited fs writes on the disk, like + flash memories + + Data will be saved into a dict: + k: (time, pydata) + """ + + def __init__(self, d_store=None): + if d_store is None: + d_store = {} + self.d_store = d_store + + def __contains__(self, key): + return key in self.d_store + + def __getitem__(self, key): + """Return the value and update the last seen value""" + t, value = self.d_store[key] + self.d_store[key] = (time.time(), value) + return value + + def __setitem__(self, key, value): + self.d_store[key] = (time.time(), value) + + def __delitem__(self, key): + del self.d_store[key] + + def cleanup(self, timeout): + now = time.time() + to_del = [] + for k, (atime, value) in iteritems(self.d_store): + if now - atime > timeout: + to_del.append(k) + + # to avoid exception on "dict change during iterations" + for k in to_del: + del self.d_store[k] + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/web/template.py b/web/template.py new file mode 100644 index 0000000..d3d47ba --- /dev/null +++ b/web/template.py @@ -0,0 +1,1742 @@ +""" +simple, elegant templating +(part of web.py) + +Template design: + +Template string is split into tokens and the tokens are combined into nodes. +Parse tree is a nodelist. TextNode and ExpressionNode are simple nodes and +for-loop, if-loop etc are block nodes, which contain multiple child nodes. + +Each node can emit some python string. python string emitted by the +root node is validated for safeeval and executed using python in the given environment. + +Enough care is taken to make sure the generated code and the template has line to line match, +so that the error messages can point to exact line number in template. (It doesn't work in some cases still.) + +Grammar: + + template -> defwith sections + defwith -> '$def with (' arguments ')' | '' + sections -> section* + section -> block | assignment | line + + assignment -> '$ ' + line -> (text|expr)* + text -> + expr -> '$' pyexpr | '$(' pyexpr ')' | '${' pyexpr '}' + pyexpr -> +""" + +import ast +import builtins +import glob +import os +import sys +import tokenize +from functools import partial + +from .net import websafe +from .utils import re_compile, safestr, safeunicode, storage +from .webapi import config + +__all__ = [ + "Template", + "Render", + "render", + "frender", + "ParseError", + "SecurityError", + "test", +] + + +from collections.abc import MutableMapping + + +def splitline(text): + r""" + Splits the given text at newline. + + >>> splitline('foo\nbar') + ('foo\n', 'bar') + >>> splitline('foo') + ('foo', '') + >>> splitline('') + ('', '') + """ + index = text.find("\n") + 1 + if index: + return text[:index], text[index:] + else: + return text, "" + + +class Parser: + """Parser Base.""" + + def __init__(self): + self.statement_nodes = STATEMENT_NODES + self.keywords = KEYWORDS + + def parse(self, text, name="