Преглед на файлове

Deployment enhancements
* Split traffic rules into distinct config
* Support hot reload of traffic rules, psinet database, and geoip database
* Retire multi-config-file support (won't be used)
* Change fail2ban integration to use main logger (to accommodate Docker integration)

Rod Hynes преди 9 години
родител
ревизия
bffb59d372

+ 34 - 36
Server/main.go

@@ -33,29 +33,23 @@ import (
 
 func main() {
 
+	var generateTrafficRulesFilename, generateServerEntryFilename string
 	var generateServerIPaddress, generateServerNetworkInterface string
-	var generateConfigFilename, generateServerEntryFilename string
 	var generateWebServerPort int
 	var generateProtocolPorts stringListFlag
-	var runConfigFilenames stringListFlag
+	var configFilename string
 
 	flag.StringVar(
-		&generateConfigFilename,
-		"newConfig",
-		server.SERVER_CONFIG_FILENAME,
-		"generate new config with this `filename`")
+		&generateTrafficRulesFilename,
+		"trafficRules",
+		server.SERVER_TRAFFIC_RULES_FILENAME,
+		"generate with this traffic rules `filename`")
 
 	flag.StringVar(
 		&generateServerEntryFilename,
-		"newServerEntry",
+		"serverEntry",
 		server.SERVER_ENTRY_FILENAME,
-		"generate new server entry with this `filename`")
-
-	flag.StringVar(
-		&generateServerNetworkInterface,
-		"interface",
-		"",
-		"generate with server IP address from this `network-interface`")
+		"generate with this server entry `filename`")
 
 	flag.StringVar(
 		&generateServerIPaddress,
@@ -63,6 +57,12 @@ func main() {
 		server.DEFAULT_SERVER_IP_ADDRESS,
 		"generate with this server `IP address`")
 
+	flag.StringVar(
+		&generateServerNetworkInterface,
+		"interface",
+		"",
+		"generate with server IP address from this `network-interface`")
+
 	flag.IntVar(
 		&generateWebServerPort,
 		"web",
@@ -74,15 +74,16 @@ func main() {
 		"protocol",
 		"generate with `protocol:port`; flag may be repeated to enable multiple protocols")
 
-	flag.Var(
-		&runConfigFilenames,
+	flag.StringVar(
+		&configFilename,
 		"config",
-		"run with this config `filename`; flag may be repeated to load multiple config files")
+		server.SERVER_CONFIG_FILENAME,
+		"run or generate with this config `filename`")
 
 	flag.Usage = func() {
 		fmt.Fprintf(os.Stderr,
 			"Usage:\n\n"+
-				"%s <flags> generate    generates a configuration and server entry\n"+
+				"%s <flags> generate    generates configuration files\n"+
 				"%s <flags> run         runs configured services\n\n",
 			os.Args[0], os.Args[0])
 		flag.PrintDefaults()
@@ -119,50 +120,47 @@ func main() {
 			}
 		}
 
-		configFileContents, serverEntryFileContents, err :=
+		configJSON, trafficRulesJSON, encodedServerEntry, err :=
 			server.GenerateConfig(
 				&server.GenerateConfigParams{
 					ServerIPAddress:      serverIPaddress,
 					EnableSSHAPIRequests: true,
 					WebServerPort:        generateWebServerPort,
 					TunnelProtocolPorts:  tunnelProtocolPorts,
+					TrafficRulesFilename: generateTrafficRulesFilename,
 				})
 		if err != nil {
 			fmt.Printf("generate failed: %s\n", err)
 			os.Exit(1)
 		}
 
-		err = ioutil.WriteFile(generateConfigFilename, configFileContents, 0600)
+		err = ioutil.WriteFile(configFilename, configJSON, 0600)
 		if err != nil {
 			fmt.Printf("error writing configuration file: %s\n", err)
 			os.Exit(1)
 		}
 
-		err = ioutil.WriteFile(generateServerEntryFilename, serverEntryFileContents, 0600)
+		err = ioutil.WriteFile(generateTrafficRulesFilename, trafficRulesJSON, 0600)
 		if err != nil {
-			fmt.Printf("error writing server entry file: %s\n", err)
+			fmt.Printf("error writing traffic rule configuration file: %s\n", err)
 			os.Exit(1)
 		}
 
-	} else if args[0] == "run" {
-
-		if len(runConfigFilenames) == 0 {
-			runConfigFilenames = []string{server.SERVER_CONFIG_FILENAME}
+		err = ioutil.WriteFile(generateServerEntryFilename, encodedServerEntry, 0600)
+		if err != nil {
+			fmt.Printf("error writing server entry file: %s\n", err)
+			os.Exit(1)
 		}
 
-		var configFileContents [][]byte
-
-		for _, configFilename := range runConfigFilenames {
-			contents, err := ioutil.ReadFile(configFilename)
-			if err != nil {
-				fmt.Printf("error loading configuration file: %s\n", err)
-				os.Exit(1)
-			}
+	} else if args[0] == "run" {
 
-			configFileContents = append(configFileContents, contents)
+		configJSON, err := ioutil.ReadFile(configFilename)
+		if err != nil {
+			fmt.Printf("error loading configuration file: %s\n", err)
+			os.Exit(1)
 		}
 
-		err := server.RunServices(configFileContents)
+		err = server.RunServices(configJSON)
 		if err != nil {
 			fmt.Printf("run failed: %s\n", err)
 			os.Exit(1)

+ 56 - 51
psiphon/server/api.go

@@ -31,7 +31,6 @@ import (
 	"unicode"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
 )
 
 const MAX_API_PARAMS_SIZE = 256 * 1024 // 256KB
@@ -50,8 +49,7 @@ type requestJSONObject map[string]interface{}
 // clients.
 //
 func sshAPIRequestHandler(
-	config *Config,
-	psinetDatabase *psinet.Database,
+	support *SupportServices,
 	geoIPData GeoIPData,
 	name string,
 	requestPayload []byte) ([]byte, error) {
@@ -67,13 +65,13 @@ func sshAPIRequestHandler(
 
 	switch name {
 	case psiphon.SERVER_API_HANDSHAKE_REQUEST_NAME:
-		return handshakeAPIRequestHandler(config, psinetDatabase, geoIPData, params)
+		return handshakeAPIRequestHandler(support, geoIPData, params)
 	case psiphon.SERVER_API_CONNECTED_REQUEST_NAME:
-		return connectedAPIRequestHandler(config, geoIPData, params)
+		return connectedAPIRequestHandler(support, geoIPData, params)
 	case psiphon.SERVER_API_STATUS_REQUEST_NAME:
-		return statusAPIRequestHandler(config, geoIPData, params)
+		return statusAPIRequestHandler(support, geoIPData, params)
 	case psiphon.SERVER_API_CLIENT_VERIFICATION_REQUEST_NAME:
-		return clientVerificationAPIRequestHandler(config, geoIPData, params)
+		return clientVerificationAPIRequestHandler(support, geoIPData, params)
 	}
 
 	return nil, psiphon.ContextError(fmt.Errorf("invalid request name: %s", name))
@@ -84,14 +82,13 @@ func sshAPIRequestHandler(
 // connection; the response tells the client what homepage to open, what
 // stats to record, etc.
 func handshakeAPIRequestHandler(
-	config *Config,
-	psinetDatabase *psinet.Database,
+	support *SupportServices,
 	geoIPData GeoIPData,
 	params requestJSONObject) ([]byte, error) {
 
 	// Note: ignoring "known_servers" params
 
-	err := validateRequestParams(config, params, baseRequestParams)
+	err := validateRequestParams(support, params, baseRequestParams)
 	if err != nil {
 		// TODO: fail2ban?
 		return nil, psiphon.ContextError(errors.New("invalid params"))
@@ -99,7 +96,7 @@ func handshakeAPIRequestHandler(
 
 	log.WithContextFields(
 		getRequestLogFields(
-			config,
+			support,
 			"handshake",
 			geoIPData,
 			params,
@@ -123,16 +120,18 @@ func handshakeAPIRequestHandler(
 	clientPlatform, _ := getStringRequestParam(params, "client_platform")
 	clientRegion := geoIPData.Country
 
-	handshakeResponse.Homepages = psinetDatabase.GetHomepages(
+	// Note: no guarantee that PsinetDatabase won't reload between calls
+
+	handshakeResponse.Homepages = support.PsinetDatabase.GetHomepages(
 		sponsorID, clientRegion, clientPlatform)
 
-	handshakeResponse.UpgradeClientVersion = psinetDatabase.GetUpgradeClientVersion(
+	handshakeResponse.UpgradeClientVersion = support.PsinetDatabase.GetUpgradeClientVersion(
 		clientVersion, clientPlatform)
 
-	handshakeResponse.HttpsRequestRegexes = psinetDatabase.GetHttpsRequestRegexes(
+	handshakeResponse.HttpsRequestRegexes = support.PsinetDatabase.GetHttpsRequestRegexes(
 		sponsorID)
 
-	handshakeResponse.EncodedServerList = psinetDatabase.DiscoverServers(
+	handshakeResponse.EncodedServerList = support.PsinetDatabase.DiscoverServers(
 		propagationChannelID, geoIPData.DiscoveryValue)
 
 	handshakeResponse.ClientRegion = clientRegion
@@ -159,9 +158,11 @@ var connectedRequestParams = append(
 // which should be a connected_timestamp output from a previous connected
 // response, is used to calculate unique user stats.
 func connectedAPIRequestHandler(
-	config *Config, geoIPData GeoIPData, params requestJSONObject) ([]byte, error) {
+	support *SupportServices,
+	geoIPData GeoIPData,
+	params requestJSONObject) ([]byte, error) {
 
-	err := validateRequestParams(config, params, connectedRequestParams)
+	err := validateRequestParams(support, params, connectedRequestParams)
 	if err != nil {
 		// TODO: fail2ban?
 		return nil, psiphon.ContextError(errors.New("invalid params"))
@@ -169,7 +170,7 @@ func connectedAPIRequestHandler(
 
 	log.WithContextFields(
 		getRequestLogFields(
-			config,
+			support,
 			"connected",
 			geoIPData,
 			params,
@@ -203,9 +204,11 @@ var statusRequestParams = append(
 // any string is accepted (regex transform may result in arbitrary
 // string). Stats processor must handle this input with care.
 func statusAPIRequestHandler(
-	config *Config, geoIPData GeoIPData, params requestJSONObject) ([]byte, error) {
+	support *SupportServices,
+	geoIPData GeoIPData,
+	params requestJSONObject) ([]byte, error) {
 
-	err := validateRequestParams(config, params, statusRequestParams)
+	err := validateRequestParams(support, params, statusRequestParams)
 	if err != nil {
 		// TODO: fail2ban?
 		return nil, psiphon.ContextError(errors.New("invalid params"))
@@ -223,7 +226,7 @@ func statusAPIRequestHandler(
 		return nil, psiphon.ContextError(err)
 	}
 	bytesTransferredFields := getRequestLogFields(
-		config, "bytes_transferred", geoIPData, params, statusRequestParams)
+		support, "bytes_transferred", geoIPData, params, statusRequestParams)
 	bytesTransferredFields["bytes"] = bytesTransferred
 	log.WithContextFields(bytesTransferredFields).Info("API event")
 
@@ -237,7 +240,7 @@ func statusAPIRequestHandler(
 			return nil, psiphon.ContextError(err)
 		}
 		domainBytesFields := getRequestLogFields(
-			config, "domain_bytes", geoIPData, params, statusRequestParams)
+			support, "domain_bytes", geoIPData, params, statusRequestParams)
 		for domain, bytes := range hostBytes {
 			domainBytesFields["domain"] = domain
 			domainBytesFields["bytes"] = bytes
@@ -255,7 +258,7 @@ func statusAPIRequestHandler(
 			return nil, psiphon.ContextError(err)
 		}
 		sessionFields := getRequestLogFields(
-			config, "session", geoIPData, params, statusRequestParams)
+			support, "session", geoIPData, params, statusRequestParams)
 		for _, tunnelStat := range tunnelStats {
 
 			sessionID, err := getStringRequestParam(tunnelStat, "session_id")
@@ -317,9 +320,11 @@ func statusAPIRequestHandler(
 // verification request once per tunnel connection. The payload
 // attests that client is a legitimate Psiphon client.
 func clientVerificationAPIRequestHandler(
-	config *Config, geoIPData GeoIPData, params requestJSONObject) ([]byte, error) {
+	support *SupportServices,
+	geoIPData GeoIPData,
+	params requestJSONObject) ([]byte, error) {
 
-	err := validateRequestParams(config, params, baseRequestParams)
+	err := validateRequestParams(support, params, baseRequestParams)
 	if err != nil {
 		// TODO: fail2ban?
 		return nil, psiphon.ContextError(errors.New("invalid params"))
@@ -332,7 +337,7 @@ func clientVerificationAPIRequestHandler(
 
 type requestParamSpec struct {
 	name      string
-	validator func(*Config, string) bool
+	validator func(*SupportServices, string) bool
 	flags     uint32
 }
 
@@ -365,7 +370,7 @@ var baseRequestParams = []requestParamSpec{
 }
 
 func validateRequestParams(
-	config *Config,
+	support *SupportServices,
 	params requestJSONObject,
 	expectedParams []requestParamSpec) error {
 
@@ -383,7 +388,7 @@ func validateRequestParams(
 			return psiphon.ContextError(
 				fmt.Errorf("unexpected param type: %s", expectedParam.name))
 		}
-		if !expectedParam.validator(config, strValue) {
+		if !expectedParam.validator(support, strValue) {
 			return psiphon.ContextError(
 				fmt.Errorf("invalid param: %s", expectedParam.name))
 		}
@@ -395,7 +400,7 @@ func validateRequestParams(
 // getRequestLogFields makes LogFields to log the API event following
 // the legacy psi_web and current ELK naming conventions.
 func getRequestLogFields(
-	config *Config,
+	support *SupportServices,
 	eventName string,
 	geoIPData GeoIPData,
 	params requestJSONObject,
@@ -404,7 +409,7 @@ func getRequestLogFields(
 	logFields := make(LogFields)
 
 	logFields["event_name"] = eventName
-	logFields["host_id"] = config.HostID
+	logFields["host_id"] = support.Config.HostID
 
 	// In psi_web, the space replacement was done to accommodate space
 	// delimited logging, which is no longer required; we retain the
@@ -449,7 +454,7 @@ func getRequestLogFields(
 			logFields[expectedParam.name] = intValue
 		case "meek_dial_address":
 			host, _, _ := net.SplitHostPort(strValue)
-			if isIPAddress(config, host) {
+			if isIPAddress(support, host) {
 				logFields["meek_dial_ip_address"] = host
 			} else {
 				logFields["meek_dial_domain"] = host
@@ -544,40 +549,40 @@ func getMapStringInt64RequestParam(params requestJSONObject, name string) (map[s
 
 // Input validators follow the legacy validations rules in psi_web.
 
-func isServerSecret(config *Config, value string) bool {
+func isServerSecret(support *SupportServices, value string) bool {
 	return subtle.ConstantTimeCompare(
 		[]byte(value),
-		[]byte(config.WebServerSecret)) == 1
+		[]byte(support.Config.WebServerSecret)) == 1
 }
 
-func isHexDigits(_ *Config, value string) bool {
+func isHexDigits(_ *SupportServices, value string) bool {
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		return !unicode.Is(unicode.ASCII_Hex_Digit, c)
 	})
 }
 
-func isDigits(_ *Config, value string) bool {
+func isDigits(_ *SupportServices, value string) bool {
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		return c < '0' || c > '9'
 	})
 }
 
-func isClientPlatform(_ *Config, value string) bool {
+func isClientPlatform(_ *SupportServices, value string) bool {
 	return -1 == strings.IndexFunc(value, func(c rune) bool {
 		// Note: stricter than psi_web's Python string.whitespace
 		return unicode.Is(unicode.White_Space, c)
 	})
 }
 
-func isRelayProtocol(_ *Config, value string) bool {
+func isRelayProtocol(_ *SupportServices, value string) bool {
 	return psiphon.Contains(psiphon.SupportedTunnelProtocols, value)
 }
 
-func isBooleanFlag(_ *Config, value string) bool {
+func isBooleanFlag(_ *SupportServices, value string) bool {
 	return value == "0" || value == "1"
 }
 
-func isRegionCode(_ *Config, value string) bool {
+func isRegionCode(_ *SupportServices, value string) bool {
 	if len(value) != 2 {
 		return false
 	}
@@ -586,16 +591,16 @@ func isRegionCode(_ *Config, value string) bool {
 	})
 }
 
-func isDialAddress(config *Config, value string) bool {
+func isDialAddress(support *SupportServices, value string) bool {
 	// "<host>:<port>", where <host> is a domain or IP address
 	parts := strings.Split(value, ":")
 	if len(parts) != 2 {
 		return false
 	}
-	if !isIPAddress(config, parts[0]) && !isDomain(config, parts[0]) {
+	if !isIPAddress(support, parts[0]) && !isDomain(support, parts[0]) {
 		return false
 	}
-	if !isDigits(config, parts[1]) {
+	if !isDigits(support, parts[1]) {
 		return false
 	}
 	port, err := strconv.Atoi(parts[1])
@@ -605,13 +610,13 @@ func isDialAddress(config *Config, value string) bool {
 	return port > 0 && port < 65536
 }
 
-func isIPAddress(_ *Config, value string) bool {
+func isIPAddress(_ *SupportServices, value string) bool {
 	return net.ParseIP(value) != nil
 }
 
 var isDomainRegex = regexp.MustCompile("[a-zA-Z\\d-]{1,63}$")
 
-func isDomain(_ *Config, value string) bool {
+func isDomain(_ *SupportServices, value string) bool {
 
 	// From: http://stackoverflow.com/questions/2532053/validate-a-hostname-string
 	//
@@ -638,25 +643,25 @@ func isDomain(_ *Config, value string) bool {
 	return true
 }
 
-func isHostHeader(config *Config, value string) bool {
+func isHostHeader(support *SupportServices, value string) bool {
 	// "<host>:<port>", where <host> is a domain or IP address and ":<port>" is optional
 	if strings.Contains(value, ":") {
-		return isDialAddress(config, value)
+		return isDialAddress(support, value)
 	}
-	return isIPAddress(config, value) || isDomain(config, value)
+	return isIPAddress(support, value) || isDomain(support, value)
 }
 
-func isServerEntrySource(_ *Config, value string) bool {
+func isServerEntrySource(_ *SupportServices, value string) bool {
 	return psiphon.Contains(psiphon.SupportedServerEntrySources, value)
 }
 
 var isISO8601DateRegex = regexp.MustCompile(
 	"(?P<year>[0-9]{4})-(?P<month>[0-9]{1,2})-(?P<day>[0-9]{1,2})T(?P<hour>[0-9]{2}):(?P<minute>[0-9]{2}):(?P<second>[0-9]{2})(\\.(?P<fraction>[0-9]+))?(?P<timezone>Z|(([-+])([0-9]{2}):([0-9]{2})))")
 
-func isISO8601Date(_ *Config, value string) bool {
+func isISO8601Date(_ *SupportServices, value string) bool {
 	return isISO8601DateRegex.Match([]byte(value))
 }
 
-func isLastConnected(config *Config, value string) bool {
-	return value == "None" || value == "Unknown" || isISO8601Date(config, value)
+func isLastConnected(support *SupportServices, value string) bool {
+	return value == "None" || value == "Unknown" || isISO8601Date(support, value)
 }

+ 55 - 166
psiphon/server/config.go

@@ -39,8 +39,9 @@ import (
 )
 
 const (
-	SERVER_CONFIG_FILENAME                = "psiphon-server.config"
-	SERVER_ENTRY_FILENAME                 = "serverEntry.dat"
+	SERVER_CONFIG_FILENAME                = "psiphond.config"
+	SERVER_TRAFFIC_RULES_FILENAME         = "psiphond-traffic-rules.config"
+	SERVER_ENTRY_FILENAME                 = "server-entry.dat"
 	DEFAULT_SERVER_IP_ADDRESS             = "127.0.0.1"
 	WEB_SERVER_SECRET_BYTE_LENGTH         = 32
 	DISCOVERY_VALUE_KEY_BYTE_LENGTH       = 32
@@ -57,8 +58,6 @@ const (
 	GEOIP_SESSION_CACHE_TTL               = 60 * time.Minute
 )
 
-// TODO: break config into sections (sub-structs)
-
 // Config specifies the configuration and behavior of a Psiphon
 // server.
 type Config struct {
@@ -81,9 +80,8 @@ type Config struct {
 	// Fail2BanFormat is a string format specifier for the
 	// log message format to use for fail2ban integration for
 	// blocking abusive clients by source IP address.
-	// When set, logs with this format are made to the AUTH
-	// facility with INFO severity in the local syslog server
-	// if clients fail to authenticate.
+	// When set, logs with this format are made if clients fail
+	// to authenticate.
 	// The client's IP address is included with the log message.
 	// An example format specifier, which should be compatible
 	// with default SSH fail2ban configuration, is
@@ -203,100 +201,16 @@ type Config struct {
 	// tunneled DNS UDP packets will be re-routed to this destination.
 	UDPForwardDNSServerAddress string
 
-	// DefaultTrafficRules specifies the traffic rules to be used when
-	// no regional-specific rules are set.
-	DefaultTrafficRules TrafficRules
-
-	// RegionalTrafficRules specifies the traffic rules for particular
-	// client regions (countries) as determined by GeoIP lookup of the
-	// client IP address. The key for each regional traffic rule entry
-	// is one or more space delimited ISO 3166-1 alpha-2 country codes.
-	RegionalTrafficRules map[string]TrafficRules
-
 	// LoadMonitorPeriodSeconds indicates how frequently to log server
 	// load information (number of connected clients per tunnel protocol,
 	// number of running goroutines, amount of memory allocated, etc.)
 	// The default, 0, disables load logging.
 	LoadMonitorPeriodSeconds int
-}
-
-// RateLimits specify the rate limits for tunneled data transfer
-// between an individual client and the server.
-type RateLimits struct {
 
-	// DownstreamUnlimitedBytes specifies the number of downstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	DownstreamUnlimitedBytes int64
-
-	// DownstreamBytesPerSecond specifies a rate limit for downstream
-	// data transfer. The default, 0, is no limit.
-	DownstreamBytesPerSecond int
-
-	// UpstreamUnlimitedBytes specifies the number of upstream
-	// bytes to transfer, approximately, before starting rate
-	// limiting.
-	UpstreamUnlimitedBytes int64
-
-	// UpstreamBytesPerSecond specifies a rate limit for upstream
-	// data transfer. The default, 0, is no limit.
-	UpstreamBytesPerSecond int
-}
-
-// TrafficRules specify the limits placed on client traffic.
-type TrafficRules struct {
-	// DefaultRateLimitsare the rate limits to be applied when
-	// no protocol-specific rates are set.
-	DefaultRateLimits RateLimits
-
-	// ProtocolRateLimits specifies the rate limits for particular
-	// tunnel protocols. The key for each rate limit entry is one
-	// or more space delimited Psiphon tunnel protocol names. Valid
-	// tunnel protocols includes the same list as for
-	// TunnelProtocolPorts.
-	ProtocolRateLimits map[string]RateLimits
-
-	// IdleTCPPortForwardTimeoutMilliseconds is the timeout period
-	// after which idle (no bytes flowing in either direction)
-	// client TCP port forwards are preemptively closed.
-	// The default, 0, is no idle timeout.
-	IdleTCPPortForwardTimeoutMilliseconds int
-
-	// IdleUDPPortForwardTimeoutMilliseconds is the timeout period
-	// after which idle (no bytes flowing in either direction)
-	// client UDP port forwards are preemptively closed.
-	// The default, 0, is no idle timeout.
-	IdleUDPPortForwardTimeoutMilliseconds int
-
-	// MaxTCPPortForwardCount is the maximum number of TCP port
-	// forwards each client may have open concurrently.
-	// The default, 0, is no maximum.
-	MaxTCPPortForwardCount int
-
-	// MaxUDPPortForwardCount is the maximum number of UDP port
-	// forwards each client may have open concurrently.
-	// The default, 0, is no maximum.
-	MaxUDPPortForwardCount int
-
-	// AllowTCPPorts specifies a whitelist of TCP ports that
-	// are permitted for port forwarding. When set, only ports
-	// in the list are accessible to clients.
-	AllowTCPPorts []int
-
-	// AllowUDPPorts specifies a whitelist of UDP ports that
-	// are permitted for port forwarding. When set, only ports
-	// in the list are accessible to clients.
-	AllowUDPPorts []int
-
-	// DenyTCPPorts specifies a blacklist of TCP ports that
-	// are not permitted for port forwarding. When set, the
-	// ports in the list are inaccessible to clients.
-	DenyTCPPorts []int
-
-	// DenyUDPPorts specifies a blacklist of UDP ports that
-	// are not permitted for port forwarding. When set, the
-	// ports in the list are inaccessible to clients.
-	DenyUDPPorts []int
+	// TrafficRulesFilename is the path of a file containing a
+	// JSON-encoded TrafficRulesSet, the traffic rules to apply to
+	// Psiphon client tunnels.
+	TrafficRulesFilename string
 }
 
 // RunWebServer indicates whether to run a web server component.
@@ -315,49 +229,13 @@ func (config *Config) UseFail2Ban() bool {
 	return config.Fail2BanFormat != ""
 }
 
-// GetTrafficRules looks up the traffic rules for the specified country. If there
-// are no RegionalTrafficRules for the country, DefaultTrafficRules are used.
-func (config *Config) GetTrafficRules(clientCountryCode string) TrafficRules {
-	// TODO: faster lookup?
-	for countryCodes, trafficRules := range config.RegionalTrafficRules {
-		for _, countryCode := range strings.Split(countryCodes, " ") {
-			if countryCode == clientCountryCode {
-				return trafficRules
-			}
-		}
-	}
-	return config.DefaultTrafficRules
-}
+// LoadConfig loads and validates a JSON encoded server config.
+func LoadConfig(configJSON []byte) (*Config, error) {
 
-// GetRateLimits looks up the rate limits for the specified tunnel protocol.
-// If there are no ProtocolRateLimits for the protocol, DefaultRateLimits are used.
-func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) RateLimits {
-	// TODO: faster lookup?
-	for tunnelProtocols, rateLimits := range rules.ProtocolRateLimits {
-		for _, tunnelProtocol := range strings.Split(tunnelProtocols, " ") {
-			if tunnelProtocol == clientTunnelProtocol {
-				return rateLimits
-			}
-		}
-	}
-	return rules.DefaultRateLimits
-}
-
-// LoadConfig loads and validates a JSON encoded server config. If more than one
-// JSON config is specified, then all are loaded and values are merged together,
-// in order. Multiple configs allows for use cases like storing static, server-specific
-// values in a base config while also deploying network-wide throttling settings
-// in a secondary file that can be paved over on all server hosts.
-func LoadConfig(configJSONs [][]byte) (*Config, error) {
-
-	// Note: default values are set in GenerateConfig
 	var config Config
-
-	for _, configJSON := range configJSONs {
-		err := json.Unmarshal(configJSON, &config)
-		if err != nil {
-			return nil, psiphon.ContextError(err)
-		}
+	err := json.Unmarshal(configJSON, &config)
+	if err != nil {
+		return nil, psiphon.ContextError(err)
 	}
 
 	if config.Fail2BanFormat != "" && strings.Count(config.Fail2BanFormat, "%s") != 1 {
@@ -442,25 +320,26 @@ type GenerateConfigParams struct {
 	WebServerPort        int
 	EnableSSHAPIRequests bool
 	TunnelProtocolPorts  map[string]int
+	TrafficRulesFilename string
 }
 
-// GenerateConfig creates a new Psiphon server config. It returns a JSON
-// encoded config and a client-compatible "server entry" for the server. It
+// GenerateConfig creates a new Psiphon server config. It returns JSON
+// encoded configs and a client-compatible "server entry" for the server. It
 // generates all necessary secrets and key material, which are emitted in
 // the config file and server entry as necessary.
 // GenerateConfig uses sample values for many fields. The intention is for
-// a generated config to be used for testing or as a template for production
+// generated configs to be used for testing or as a template for production
 // setup, not to generate production-ready configurations.
-func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
+func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, []byte, error) {
 
 	// Input validation
 
 	if net.ParseIP(params.ServerIPAddress) == nil {
-		return nil, nil, psiphon.ContextError(errors.New("invalid IP address"))
+		return nil, nil, nil, psiphon.ContextError(errors.New("invalid IP address"))
 	}
 
 	if len(params.TunnelProtocolPorts) == 0 {
-		return nil, nil, psiphon.ContextError(errors.New("no tunnel protocols"))
+		return nil, nil, nil, psiphon.ContextError(errors.New("no tunnel protocols"))
 	}
 
 	usedPort := make(map[int]bool)
@@ -473,11 +352,11 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 	for protocol, port := range params.TunnelProtocolPorts {
 
 		if !psiphon.Contains(psiphon.SupportedTunnelProtocols, protocol) {
-			return nil, nil, psiphon.ContextError(errors.New("invalid tunnel protocol"))
+			return nil, nil, nil, psiphon.ContextError(errors.New("invalid tunnel protocol"))
 		}
 
 		if usedPort[port] {
-			return nil, nil, psiphon.ContextError(errors.New("duplicate listening port"))
+			return nil, nil, nil, psiphon.ContextError(errors.New("duplicate listening port"))
 		}
 		usedPort[port] = true
 
@@ -495,12 +374,12 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 		var err error
 		webServerSecret, err = psiphon.MakeRandomStringHex(WEB_SERVER_SECRET_BYTE_LENGTH)
 		if err != nil {
-			return nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, psiphon.ContextError(err)
 		}
 
 		webServerCertificate, webServerPrivateKey, err = GenerateWebServerCertificate("")
 		if err != nil {
-			return nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, psiphon.ContextError(err)
 		}
 	}
 
@@ -509,7 +388,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 	// TODO: use other key types: anti-fingerprint by varying params
 	rsaKey, err := rsa.GenerateKey(rand.Reader, SSH_RSA_HOST_KEY_BITS)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
 	sshPrivateKey := pem.EncodeToMemory(
@@ -521,21 +400,21 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
 	signer, err := ssh.NewSignerFromKey(rsaKey)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
 	sshPublicKey := signer.PublicKey()
 
 	sshUserNameSuffix, err := psiphon.MakeRandomStringHex(SSH_USERNAME_SUFFIX_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
 	sshUserName := "psiphon_" + sshUserNameSuffix
 
 	sshPassword, err := psiphon.MakeRandomStringHex(SSH_PASSWORD_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
 	// TODO: vary version string for anti-fingerprint
@@ -545,7 +424,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
 	obfuscatedSSHKey, err := psiphon.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
 	// Meek config
@@ -556,7 +435,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 		rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err :=
 			box.GenerateKey(rand.Reader)
 		if err != nil {
-			return nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, psiphon.ContextError(err)
 		}
 
 		meekCookieEncryptionPublicKey = base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPublicKey[:])
@@ -564,7 +443,7 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
 		meekObfuscatedKey, err = psiphon.MakeRandomStringHex(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
 		if err != nil {
-			return nil, nil, psiphon.ContextError(err)
+			return nil, nil, nil, psiphon.ContextError(err)
 		}
 	}
 
@@ -572,10 +451,10 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
 	discoveryValueHMACKey, err := psiphon.MakeRandomStringBase64(DISCOVERY_VALUE_KEY_BYTE_LENGTH)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
-	// Assemble config and server entry
+	// Assemble configs and server entry
 
 	// Note: this config is intended for either testing or as an illustrative
 	// example or template and is not intended for production deployment.
@@ -606,8 +485,18 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 		MeekCertificateCommonName:      "www.example.org",
 		MeekProhibitedHeaders:          nil,
 		MeekProxyForwardedForHeaders:   []string{"X-Forwarded-For"},
-		DefaultTrafficRules: TrafficRules{
-			DefaultRateLimits: RateLimits{
+		LoadMonitorPeriodSeconds:       300,
+		TrafficRulesFilename:           params.TrafficRulesFilename,
+	}
+
+	encodedConfig, err := json.MarshalIndent(config, "\n", "    ")
+	if err != nil {
+		return nil, nil, nil, psiphon.ContextError(err)
+	}
+
+	trafficRulesSet := &TrafficRulesSet{
+		DefaultRules: TrafficRules{
+			DefaultLimits: RateLimits{
 				DownstreamUnlimitedBytes: 0,
 				DownstreamBytesPerSecond: 0,
 				UpstreamUnlimitedBytes:   0,
@@ -622,18 +511,13 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 			DenyTCPPorts:                          nil,
 			DenyUDPPorts:                          nil,
 		},
-		LoadMonitorPeriodSeconds: 300,
 	}
 
-	encodedConfig, err := json.MarshalIndent(config, "\n", "    ")
+	encodedTrafficRulesSet, err := json.MarshalIndent(trafficRulesSet, "\n", "    ")
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
-	// Server entry format omits the BEGIN/END lines and newlines
-	lines := strings.Split(webServerCertificate, "\n")
-	strippedWebServerCertificate := strings.Join(lines[1:len(lines)-2], "")
-
 	capabilities := []string{}
 
 	if params.EnableSSHAPIRequests {
@@ -664,9 +548,14 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 	// a fronting hop.
 
 	serverEntryWebServerPort := ""
+	strippedWebServerCertificate := ""
 
 	if params.WebServerPort != 0 {
 		serverEntryWebServerPort = fmt.Sprintf("%d", params.WebServerPort)
+
+		// Server entry format omits the BEGIN/END lines and newlines
+		lines := strings.Split(webServerCertificate, "\n")
+		strippedWebServerCertificate = strings.Join(lines[1:len(lines)-2], "")
 	}
 
 	serverEntry := &psiphon.ServerEntry{
@@ -692,8 +581,8 @@ func GenerateConfig(params *GenerateConfigParams) ([]byte, []byte, error) {
 
 	encodedServerEntry, err := psiphon.EncodeServerEntry(serverEntry)
 	if err != nil {
-		return nil, nil, psiphon.ContextError(err)
+		return nil, nil, nil, psiphon.ContextError(err)
 	}
 
-	return encodedConfig, []byte(encodedServerEntry), nil
+	return encodedConfig, encodedTrafficRulesSet, []byte(encodedServerEntry), nil
 }

+ 62 - 43
psiphon/server/geoip.go

@@ -23,6 +23,7 @@ import (
 	"crypto/hmac"
 	"crypto/sha256"
 	"net"
+	"sync"
 	"time"
 
 	cache "github.com/Psiphon-Inc/go-cache"
@@ -32,7 +33,7 @@ import (
 
 const UNKNOWN_GEOIP_VALUE = "None"
 
-// GeoIPData stores GeoIP data for a client session. Individual client
+// GeoIPData is GeoIP data for a client session. Individual client
 // IP addresses are neither logged nor explicitly referenced during a session.
 // The GeoIP country, city, and ISP corresponding to a client IP address are
 // resolved and then logged along with usage stats. The DiscoveryValue is
@@ -55,14 +56,62 @@ func NewGeoIPData() GeoIPData {
 	}
 }
 
-// GeoIPLookup determines a GeoIPData for a given client IP address.
-func GeoIPLookup(ipAddress string) GeoIPData {
+// GeoIPService implements GeoIP lookup and session/GeoIP caching.
+// Lookup is via a MaxMind database; the ReloadDatabase function
+// supports hot reloading of MaxMind data while the server is
+// running.
+type GeoIPService struct {
+	maxMindReadeMutex     sync.RWMutex
+	maxMindReader         *maxminddb.Reader
+	sessionCache          *cache.Cache
+	discoveryValueHMACKey string
+}
+
+// NewGeoIPService initializes a new GeoIPService.
+func NewGeoIPService(databaseFilename, discoveryValueHMACKey string) (*GeoIPService, error) {
+	geoIP := &GeoIPService{
+		maxMindReader:         nil,
+		sessionCache:          cache.New(GEOIP_SESSION_CACHE_TTL, 1*time.Minute),
+		discoveryValueHMACKey: discoveryValueHMACKey,
+	}
+	return geoIP, geoIP.ReloadDatabase(databaseFilename)
+}
+
+// ReloadDatabase [re]loads a MaxMind GeoIP2/GeoLite2 database to
+// be used for GeoIP lookup. When ReloadDatabase fails, the previous
+// MaxMind database state is retained.
+// ReloadDatabase only updates the MaxMind database and doesn't affect
+// other GeopIPService components (e.g., the session cache).
+func (geoIP *GeoIPService) ReloadDatabase(databaseFilename string) error {
+	geoIP.maxMindReadeMutex.Lock()
+	defer geoIP.maxMindReadeMutex.Unlock()
+
+	if databaseFilename == "" {
+		// No database filename in the config
+		return nil
+	}
+
+	maxMindReader, err := maxminddb.Open(databaseFilename)
+	if err != nil {
+		return psiphon.ContextError(err)
+	}
+
+	geoIP.maxMindReader = maxMindReader
+
+	return nil
+}
+
+// Lookup determines a GeoIPData for a given client IP address.
+func (geoIP *GeoIPService) Lookup(ipAddress string) GeoIPData {
+	geoIP.maxMindReadeMutex.RLock()
+	defer geoIP.maxMindReadeMutex.RUnlock()
 
 	result := NewGeoIPData()
 
 	ip := net.ParseIP(ipAddress)
 
-	if ip == nil || geoIPReader == nil {
+	// Note: maxMindReader is nil when config.GeoIPDatabaseFilename is blank.
+	if ip == nil || geoIP.maxMindReader == nil {
 		return result
 	}
 
@@ -76,7 +125,7 @@ func GeoIPLookup(ipAddress string) GeoIPData {
 		ISP string `maxminddb:"isp"`
 	}
 
-	err := geoIPReader.Lookup(ip, &geoIPFields)
+	err := geoIP.maxMindReader.Lookup(ip, &geoIPFields)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Warning("GeoIP lookup failed")
 	}
@@ -94,23 +143,19 @@ func GeoIPLookup(ipAddress string) GeoIPData {
 		result.ISP = geoIPFields.ISP
 	}
 
-	result.DiscoveryValue = calculateDiscoveryValue(ipAddress)
+	result.DiscoveryValue = calculateDiscoveryValue(
+		geoIP.discoveryValueHMACKey, ipAddress)
 
 	return result
 }
 
-func SetGeoIPSessionCache(sessionID string, geoIPData GeoIPData) {
-	if geoIPSessionCache == nil {
-		return
-	}
-	geoIPSessionCache.Set(sessionID, geoIPData, cache.DefaultExpiration)
+func (geoIP *GeoIPService) SetSessionCache(sessionID string, geoIPData GeoIPData) {
+	geoIP.sessionCache.Set(sessionID, geoIPData, cache.DefaultExpiration)
 }
 
-func GetGeoIPSessionCache(sessionID string) GeoIPData {
-	if geoIPSessionCache == nil {
-		return NewGeoIPData()
-	}
-	geoIPData, found := geoIPSessionCache.Get(sessionID)
+func (geoIP *GeoIPService) GetSessionCache(
+	sessionID string) GeoIPData {
+	geoIPData, found := geoIP.sessionCache.Get(sessionID)
 	if !found {
 		return NewGeoIPData()
 	}
@@ -123,7 +168,7 @@ func GetGeoIPSessionCache(sessionID string) GeoIPData {
 // later use by the discovery algorithm.
 // See https://bitbucket.org/psiphon/psiphon-circumvention-system/src/tip/Automation/psi_ops_discovery.py
 // for full details.
-func calculateDiscoveryValue(ipAddress string) int {
+func calculateDiscoveryValue(discoveryValueHMACKey, ipAddress string) int {
 	// From: psi_ops_discovery.calculate_ip_address_strategy_value:
 	//     # Mix bits from all octets of the client IP address to determine the
 	//     # bucket. An HMAC is used to prevent pre-calculation of buckets for IPs.
@@ -133,29 +178,3 @@ func calculateDiscoveryValue(ipAddress string) int {
 	hash.Write([]byte(ipAddress))
 	return int(hash.Sum(nil)[0])
 }
-
-var geoIPReader *maxminddb.Reader
-var geoIPSessionCache *cache.Cache
-var discoveryValueHMACKey string
-
-// InitGeoIP opens a GeoIP2/GeoLite2 MaxMind database and prepares
-// it for lookups.
-func InitGeoIP(config *Config) error {
-
-	discoveryValueHMACKey = config.DiscoveryValueHMACKey
-
-	if config.GeoIPDatabaseFilename != "" {
-
-		var err error
-		geoIPReader, err = maxminddb.Open(config.GeoIPDatabaseFilename)
-		if err != nil {
-			return psiphon.ContextError(err)
-		}
-
-		geoIPSessionCache = cache.New(GEOIP_SESSION_CACHE_TTL, 1*time.Minute)
-
-		log.WithContext().Info("GeoIP initialized")
-	}
-
-	return nil
-}

+ 0 - 23
psiphon/server/log.go

@@ -20,7 +20,6 @@
 package server
 
 import (
-	"fmt"
 	"io"
 	"log/syslog"
 	"os"
@@ -70,8 +69,6 @@ func NewLogWriter() *io.PipeWriter {
 }
 
 var log *ContextLogger
-var fail2BanFormat string
-var fail2BanWriter *syslog.Writer
 
 // InitLogging configures a logger according to the specified
 // config params. If not called, the default logger set by the
@@ -111,29 +108,9 @@ func InitLogging(config *Config) error {
 		},
 	}
 
-	if config.Fail2BanFormat != "" {
-		fail2BanFormat = config.Fail2BanFormat
-		fail2BanWriter, err = syslog.Dial(
-			"", "", syslog.LOG_AUTH|syslog.LOG_INFO, config.SyslogTag)
-		if err != nil {
-			return psiphon.ContextError(err)
-		}
-	}
-
 	return nil
 }
 
-// LogFail2Ban logs a message to the local syslog service AUTH
-// facility with INFO severity using the format specified by
-// config.Fail2BanFormat and the given client IP address. This
-// is for integration with fail2ban for blocking abusive
-// clients by source IP address. When set, the tag in
-// config.SyslogTag is used.
-func LogFail2Ban(clientIPAddress string) {
-	fail2BanWriter.Info(
-		fmt.Sprintf(fail2BanFormat, clientIPAddress))
-}
-
 // getSyslogPriority determines golang's syslog "priority" value
 // based on the provided config.
 func getSyslogPriority(config *Config) syslog.Priority {

+ 15 - 14
psiphon/server/meek.go

@@ -76,7 +76,7 @@ const MEEK_MAX_SESSION_ID_LENGTH = 20
 // HTTP payload traffic for a given session into net.Conn conforming Read()s and Write()s via
 // the meekConn struct.
 type MeekServer struct {
-	config        *Config
+	support       *SupportServices
 	listener      net.Listener
 	tlsConfig     *tls.Config
 	clientHandler func(clientConn net.Conn)
@@ -88,14 +88,14 @@ type MeekServer struct {
 
 // NewMeekServer initializes a new meek server.
 func NewMeekServer(
-	config *Config,
+	support *SupportServices,
 	listener net.Listener,
 	useTLS bool,
 	clientHandler func(clientConn net.Conn),
 	stopBroadcast <-chan struct{}) (*MeekServer, error) {
 
 	meekServer := &MeekServer{
-		config:        config,
+		support:       support,
 		listener:      listener,
 		clientHandler: clientHandler,
 		openConns:     new(psiphon.Conns),
@@ -104,7 +104,7 @@ func NewMeekServer(
 	}
 
 	if useTLS {
-		tlsConfig, err := makeMeekTLSConfig(config)
+		tlsConfig, err := makeMeekTLSConfig(support)
 		if err != nil {
 			return nil, psiphon.ContextError(err)
 		}
@@ -199,8 +199,8 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
 		return
 	}
 
-	if len(server.config.MeekProhibitedHeaders) > 0 {
-		for _, header := range server.config.MeekProhibitedHeaders {
+	if len(server.support.Config.MeekProhibitedHeaders) > 0 {
+		for _, header := range server.support.Config.MeekProhibitedHeaders {
 			value := request.Header.Get(header)
 			if header != "" {
 				log.WithContextFields(LogFields{
@@ -284,7 +284,7 @@ func (server *MeekServer) getSession(
 	// The session is new (or expired). Treat the cookie value as a new meek
 	// cookie, extract the payload, and create a new session.
 
-	payloadJSON, err := getMeekCookiePayload(server.config, meekCookie.Value)
+	payloadJSON, err := getMeekCookiePayload(server.support, meekCookie.Value)
 	if err != nil {
 		return "", nil, psiphon.ContextError(err)
 	}
@@ -309,8 +309,8 @@ func (server *MeekServer) getSession(
 
 	clientIP := strings.Split(request.RemoteAddr, ":")[0]
 
-	if len(server.config.MeekProxyForwardedForHeaders) > 0 {
-		for _, header := range server.config.MeekProxyForwardedForHeaders {
+	if len(server.support.Config.MeekProxyForwardedForHeaders) > 0 {
+		for _, header := range server.support.Config.MeekProxyForwardedForHeaders {
 			value := request.Header.Get(header)
 			if len(value) > 0 {
 				// Some headers, such as X-Forwarded-For, are a comma-separated
@@ -451,10 +451,10 @@ func (session *meekSession) expired() bool {
 // Currently, this config is optimized for fronted meek where the nature
 // of the connection is non-circumvention; it's optimized for performance
 // assuming the peer is an uncensored CDN.
-func makeMeekTLSConfig(config *Config) (*tls.Config, error) {
+func makeMeekTLSConfig(support *SupportServices) (*tls.Config, error) {
 
 	certificate, privateKey, err := GenerateWebServerCertificate(
-		config.MeekCertificateCommonName)
+		support.Config.MeekCertificateCommonName)
 	if err != nil {
 		return nil, psiphon.ContextError(err)
 	}
@@ -501,7 +501,7 @@ func makeMeekTLSConfig(config *Config) (*tls.Config, error) {
 
 // getMeekCookiePayload extracts the payload from a meek cookie. The cookie
 // paylod is base64 encoded, obfuscated, and NaCl encrypted.
-func getMeekCookiePayload(config *Config, cookieValue string) ([]byte, error) {
+func getMeekCookiePayload(support *SupportServices, cookieValue string) ([]byte, error) {
 	decodedValue, err := base64.StdEncoding.DecodeString(cookieValue)
 	if err != nil {
 		return nil, psiphon.ContextError(err)
@@ -516,7 +516,7 @@ func getMeekCookiePayload(config *Config, cookieValue string) ([]byte, error) {
 
 	obfuscator, err := psiphon.NewServerObfuscator(
 		reader,
-		&psiphon.ObfuscatorConfig{Keyword: config.MeekObfuscatedKey})
+		&psiphon.ObfuscatorConfig{Keyword: support.Config.MeekObfuscatedKey})
 	if err != nil {
 		return nil, psiphon.ContextError(err)
 	}
@@ -532,7 +532,8 @@ func getMeekCookiePayload(config *Config, cookieValue string) ([]byte, error) {
 	var nonce [24]byte
 	var privateKey, ephemeralPublicKey [32]byte
 
-	decodedPrivateKey, err := base64.StdEncoding.DecodeString(config.MeekCookieEncryptionPrivateKey)
+	decodedPrivateKey, err := base64.StdEncoding.DecodeString(
+		support.Config.MeekCookieEncryptionPrivateKey)
 	if err != nil {
 		return nil, psiphon.ContextError(err)
 	}

+ 18 - 16
psiphon/server/psinet/psinet.go

@@ -30,9 +30,10 @@ import (
 )
 
 // Database serves Psiphon API data requests. It's safe for
-// concurrent usage.
+// concurrent usage. The Reload function supports hot reloading
+// of Psiphon network data while the server is running.
 type Database struct {
-	mutex sync.RWMutex
+	sync.RWMutex
 
 	// TODO: implement
 }
@@ -43,7 +44,7 @@ func NewDatabase(filename string) (*Database, error) {
 
 	database := &Database{}
 
-	err := database.Load(filename)
+	err := database.Reload(filename)
 	if err != nil {
 		return nil, psiphon.ContextError(err)
 	}
@@ -51,14 +52,15 @@ func NewDatabase(filename string) (*Database, error) {
 	return database, nil
 }
 
-// Load [re]initializes the Database with the Psiphon network data
+// Reload [re]initializes the Database with the Psiphon network data
 // in the specified file. This function obtains a write lock on
 // the database, blocking all readers.
 // The input "" is valid and initializes a functional Database
-// with no data.
-func (db *Database) Load(filename string) error {
-	db.mutex.Lock()
-	defer db.mutex.Unlock()
+// with no data. When Reload fails, the previous Database state is
+// retained.
+func (db *Database) Reload(filename string) error {
+	db.Lock()
+	defer db.Unlock()
 
 	// TODO: implement
 
@@ -68,8 +70,8 @@ func (db *Database) Load(filename string) error {
 // GetHomepages returns a list of  home pages for the specified sponsor,
 // region, and platform.
 func (db *Database) GetHomepages(sponsorID, clientRegion, clientPlatform string) []string {
-	db.mutex.RLock()
-	defer db.mutex.RUnlock()
+	db.RLock()
+	defer db.RUnlock()
 
 	// TODO: implement
 
@@ -80,8 +82,8 @@ func (db *Database) GetHomepages(sponsorID, clientRegion, clientPlatform string)
 // indicated for the specified client current version. The result is "" when
 // no upgrade is available.
 func (db *Database) GetUpgradeClientVersion(clientVersion, clientPlatform string) string {
-	db.mutex.RLock()
-	defer db.mutex.RUnlock()
+	db.RLock()
+	defer db.RUnlock()
 
 	// TODO: implement
 
@@ -91,8 +93,8 @@ func (db *Database) GetUpgradeClientVersion(clientVersion, clientPlatform string
 // GetHttpsRequestRegexes returns bytes transferred stats regexes for the
 // specified sponsor.
 func (db *Database) GetHttpsRequestRegexes(sponsorID string) []map[string]string {
-	db.mutex.RLock()
-	defer db.mutex.RUnlock()
+	db.RLock()
+	defer db.RUnlock()
 
 	return make([]map[string]string, 0)
 }
@@ -100,8 +102,8 @@ func (db *Database) GetHttpsRequestRegexes(sponsorID string) []map[string]string
 // DiscoverServers selects new encoded server entries to be "discovered" by
 // the client, using the discoveryValue as the input into the discovery algorithm.
 func (db *Database) DiscoverServers(propagationChannelID string, discoveryValue int) []string {
-	db.mutex.RLock()
-	defer db.mutex.RUnlock()
+	db.RLock()
+	defer db.RUnlock()
 
 	// TODO: implement
 

+ 9 - 8
psiphon/server/server_test.go

@@ -93,7 +93,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// create a server
 
-	serverConfigFileContents, serverEntryFileContents, err := GenerateConfig(
+	serverConfigJSON, _, encodedServerEntry, err := GenerateConfig(
 		&GenerateConfigParams{
 			ServerIPAddress:      "127.0.0.1",
 			EnableSSHAPIRequests: runConfig.enableSSHAPIRequests,
@@ -107,9 +107,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	// customize server config
 
 	var serverConfig interface{}
-	json.Unmarshal(serverConfigFileContents, &serverConfig)
+	json.Unmarshal(serverConfigJSON, &serverConfig)
 	serverConfig.(map[string]interface{})["GeoIPDatabaseFilename"] = ""
-	serverConfigFileContents, _ = json.Marshal(serverConfig)
+	serverConfig.(map[string]interface{})["TrafficRulesFilename"] = ""
+	serverConfigJSON, _ = json.Marshal(serverConfig)
 
 	// run server
 
@@ -117,7 +118,7 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	serverWaitGroup.Add(1)
 	go func() {
 		defer serverWaitGroup.Done()
-		err := RunServices([][]byte{serverConfigFileContents})
+		err := RunServices(serverConfigJSON)
 		if err != nil {
 			// TODO: wrong goroutine for t.FatalNow()
 			t.Fatalf("error running server: %s", err)
@@ -149,23 +150,23 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 
 	// TODO: currently, TargetServerEntry only works with one tunnel
 	numTunnels := 1
-	localHTTPProxyPort := 8080
+	localHTTPProxyPort := 8081
 	establishTunnelPausePeriodSeconds := 1
 
 	// Note: calling LoadConfig ensures all *int config fields are initialized
-	configJson := `
+	clientConfigJSON := `
 	{
 	"ClientVersion":                     "0",
 	"PropagationChannelId":              "0",
 	"SponsorId":                         "0"
 	}`
-	clientConfig, _ := psiphon.LoadConfig([]byte(configJson))
+	clientConfig, _ := psiphon.LoadConfig([]byte(clientConfigJSON))
 
 	clientConfig.ConnectionWorkerPoolSize = numTunnels
 	clientConfig.TunnelPoolSize = numTunnels
 	clientConfig.DisableRemoteServerListFetcher = true
 	clientConfig.EstablishTunnelPausePeriodSeconds = &establishTunnelPausePeriodSeconds
-	clientConfig.TargetServerEntry = string(serverEntryFileContents)
+	clientConfig.TargetServerEntry = string(encodedServerEntry)
 	clientConfig.TunnelProtocol = runConfig.tunnelProtocol
 	clientConfig.LocalHttpProxyPort = localHTTPProxyPort
 

+ 85 - 19
psiphon/server/services.go

@@ -38,9 +38,9 @@ import (
 // RunServices initializes support functions including logging and GeoIP services;
 // and then starts the server components and runs them until os.Interrupt or
 // os.Kill signals are received. The config determines which components are run.
-func RunServices(encodedConfigs [][]byte) error {
+func RunServices(configJSON []byte) error {
 
-	config, err := LoadConfig(encodedConfigs)
+	config, err := LoadConfig(configJSON)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Error("load config failed")
 		return psiphon.ContextError(err)
@@ -52,15 +52,9 @@ func RunServices(encodedConfigs [][]byte) error {
 		return psiphon.ContextError(err)
 	}
 
-	err = InitGeoIP(config)
+	supportServices, err := NewSupportServices(config)
 	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Error("init GeoIP failed")
-		return psiphon.ContextError(err)
-	}
-
-	psinetDatabase, err := psinet.NewDatabase(config.PsinetDatabaseFilename)
-	if err != nil {
-		log.WithContextFields(LogFields{"error": err}).Error("init PsinetDatabase failed")
+		log.WithContextFields(LogFields{"error": err}).Error("init support services failed")
 		return psiphon.ContextError(err)
 	}
 
@@ -68,7 +62,7 @@ func RunServices(encodedConfigs [][]byte) error {
 	shutdownBroadcast := make(chan struct{})
 	errors := make(chan error)
 
-	tunnelServer, err := NewTunnelServer(config, psinetDatabase, shutdownBroadcast)
+	tunnelServer, err := NewTunnelServer(supportServices, shutdownBroadcast)
 	if err != nil {
 		log.WithContextFields(LogFields{"error": err}).Error("init tunnel server failed")
 		return psiphon.ContextError(err)
@@ -85,7 +79,7 @@ func RunServices(encodedConfigs [][]byte) error {
 				case <-shutdownBroadcast:
 					return
 				case <-ticker.C:
-					logLoad(tunnelServer)
+					logServerLoad(tunnelServer)
 				}
 			}
 		}()
@@ -95,7 +89,7 @@ func RunServices(encodedConfigs [][]byte) error {
 		waitGroup.Add(1)
 		go func() {
 			defer waitGroup.Done()
-			err := RunWebServer(config, psinetDatabase, shutdownBroadcast)
+			err := RunWebServer(supportServices, shutdownBroadcast)
 			select {
 			case errors <- err:
 			default:
@@ -119,17 +113,23 @@ func RunServices(encodedConfigs [][]byte) error {
 	systemStopSignal := make(chan os.Signal, 1)
 	signal.Notify(systemStopSignal, os.Interrupt, os.Kill)
 
-	// SIGUSR1 triggers a load log
-	logLoadSignal := make(chan os.Signal, 1)
-	signal.Notify(logLoadSignal, syscall.SIGUSR1)
+	// SIGUSR1 triggers a reload of support services
+	reloadSupportServicesSignal := make(chan os.Signal, 1)
+	signal.Notify(reloadSupportServicesSignal, syscall.SIGUSR1)
+
+	// SIGUSR2 triggers an immediate load log
+	logServerLoadSignal := make(chan os.Signal, 1)
+	signal.Notify(logServerLoadSignal, syscall.SIGUSR2)
 
 	err = nil
 
 loop:
 	for {
 		select {
-		case <-logLoadSignal:
-			logLoad(tunnelServer)
+		case <-reloadSupportServicesSignal:
+			supportServices.Reload()
+		case <-logServerLoadSignal:
+			logServerLoad(tunnelServer)
 		case <-systemStopSignal:
 			log.WithContext().Info("shutdown by system")
 			break loop
@@ -145,7 +145,7 @@ loop:
 	return err
 }
 
-func logLoad(server *TunnelServer) {
+func logServerLoad(server *TunnelServer) {
 
 	// golang runtime stats
 	var memStats runtime.MemStats
@@ -166,3 +166,69 @@ func logLoad(server *TunnelServer) {
 
 	log.WithContextFields(fields).Info("load")
 }
+
+// SupportServices carries common and shared data components
+// across different server components. SupportServices implements a
+// hot reload of traffic rules, psinet database, and geo IP database
+// components, which allows these data components to be refreshed
+// without restarting the server process.
+type SupportServices struct {
+	Config          *Config
+	TrafficRulesSet *TrafficRulesSet
+	PsinetDatabase  *psinet.Database
+	GeoIPService    *GeoIPService
+}
+
+// NewSupportServices initializes a new SupportServices.
+func NewSupportServices(config *Config) (*SupportServices, error) {
+	trafficRulesSet, err := NewTrafficRulesSet(config.TrafficRulesFilename)
+	if err != nil {
+		return nil, psiphon.ContextError(err)
+	}
+
+	psinetDatabase, err := psinet.NewDatabase(config.PsinetDatabaseFilename)
+	if err != nil {
+		return nil, psiphon.ContextError(err)
+	}
+
+	geoIPService, err := NewGeoIPService(
+		config.GeoIPDatabaseFilename, config.DiscoveryValueHMACKey)
+	if err != nil {
+		return nil, psiphon.ContextError(err)
+	}
+
+	return &SupportServices{
+		Config:          config,
+		TrafficRulesSet: trafficRulesSet,
+		PsinetDatabase:  psinetDatabase,
+		GeoIPService:    geoIPService,
+	}, nil
+}
+
+// Reload reinitializes traffic rules, psinet database, and geo IP database
+// components. If any component fails to reload, an error is logged and
+// Reload proceeds, using the previous state of the component.
+//
+// Note: reload of traffic rules currently doesn't apply to existing,
+// established clients.
+//
+func (support *SupportServices) Reload() {
+
+	err := support.TrafficRulesSet.Reload(support.Config.TrafficRulesFilename)
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err}).Error("reload traffic rules failed")
+		// Keep running with previous state of support.TrafficRulesSet
+	}
+
+	err = support.PsinetDatabase.Reload(support.Config.PsinetDatabaseFilename)
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err}).Error("reload psinet database failed")
+		// Keep running with previous state of support.PsinetDatabase
+	}
+
+	err = support.GeoIPService.ReloadDatabase(support.Config.GeoIPDatabaseFilename)
+	if err != nil {
+		log.WithContextFields(LogFields{"error": err}).Error("reload GeoIP database failed")
+		// Keep running with previous state of support.GeoIPService
+	}
+}

+ 195 - 0
psiphon/server/trafficRules.go

@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2016, Psiphon Inc.
+ * All rights reserved.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+package server
+
+import (
+	"encoding/json"
+	"io/ioutil"
+	"strings"
+	"sync"
+
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
+)
+
+// TrafficRulesSet represents the various traffic rules to
+// apply to Psiphon client tunnels. The Reload function supports
+// hot reloading of rules data while the server is running.
+type TrafficRulesSet struct {
+	sync.RWMutex
+
+	// DefaultRules specifies the traffic rules to be used when no
+	// regional-specific rules are set or apply to a particular
+	// client.
+	DefaultRules TrafficRules
+
+	// RegionalRules specifies the traffic rules for particular client
+	// regions (countries) as determined by GeoIP lookup of the client
+	// IP address. The key for each regional traffic rule entry is one
+	// or more space delimited ISO 3166-1 alpha-2 country codes.
+	RegionalRules map[string]TrafficRules
+}
+
+// RateLimits specify the rate limits for tunneled data transfer
+// between an individual client and the server.
+type RateLimits struct {
+
+	// DownstreamUnlimitedBytes specifies the number of downstream
+	// bytes to transfer, approximately, before starting rate
+	// limiting.
+	DownstreamUnlimitedBytes int64
+
+	// DownstreamBytesPerSecond specifies a rate limit for downstream
+	// data transfer. The default, 0, is no limit.
+	DownstreamBytesPerSecond int
+
+	// UpstreamUnlimitedBytes specifies the number of upstream
+	// bytes to transfer, approximately, before starting rate
+	// limiting.
+	UpstreamUnlimitedBytes int64
+
+	// UpstreamBytesPerSecond specifies a rate limit for upstream
+	// data transfer. The default, 0, is no limit.
+	UpstreamBytesPerSecond int
+}
+
+// TrafficRules specify the limits placed on client traffic.
+type TrafficRules struct {
+	// DefaultLimits are the rate limits to be applied when
+	// no protocol-specific rates are set.
+	DefaultLimits RateLimits
+
+	// ProtocolLimits specifies the rate limits for particular
+	// tunnel protocols. The key for each rate limit entry is one
+	// or more space delimited Psiphon tunnel protocol names. Valid
+	// tunnel protocols includes the same list as for
+	// TunnelProtocolPorts.
+	ProtocolLimits map[string]RateLimits
+
+	// IdleTCPPortForwardTimeoutMilliseconds is the timeout period
+	// after which idle (no bytes flowing in either direction)
+	// client TCP port forwards are preemptively closed.
+	// The default, 0, is no idle timeout.
+	IdleTCPPortForwardTimeoutMilliseconds int
+
+	// IdleUDPPortForwardTimeoutMilliseconds is the timeout period
+	// after which idle (no bytes flowing in either direction)
+	// client UDP port forwards are preemptively closed.
+	// The default, 0, is no idle timeout.
+	IdleUDPPortForwardTimeoutMilliseconds int
+
+	// MaxTCPPortForwardCount is the maximum number of TCP port
+	// forwards each client may have open concurrently.
+	// The default, 0, is no maximum.
+	MaxTCPPortForwardCount int
+
+	// MaxUDPPortForwardCount is the maximum number of UDP port
+	// forwards each client may have open concurrently.
+	// The default, 0, is no maximum.
+	MaxUDPPortForwardCount int
+
+	// AllowTCPPorts specifies a whitelist of TCP ports that
+	// are permitted for port forwarding. When set, only ports
+	// in the list are accessible to clients.
+	AllowTCPPorts []int
+
+	// AllowUDPPorts specifies a whitelist of UDP ports that
+	// are permitted for port forwarding. When set, only ports
+	// in the list are accessible to clients.
+	AllowUDPPorts []int
+
+	// DenyTCPPorts specifies a blacklist of TCP ports that
+	// are not permitted for port forwarding. When set, the
+	// ports in the list are inaccessible to clients.
+	DenyTCPPorts []int
+
+	// DenyUDPPorts specifies a blacklist of UDP ports that
+	// are not permitted for port forwarding. When set, the
+	// ports in the list are inaccessible to clients.
+	DenyUDPPorts []int
+}
+
+// NewTrafficRulesSet initializes a TrafficRulesSet with
+// the rules data in the specified config file.
+func NewTrafficRulesSet(ruleSetFilename string) (*TrafficRulesSet, error) {
+	set := &TrafficRulesSet{}
+	return set, set.Reload(ruleSetFilename)
+}
+
+// Reload [re]initializes the TrafficRulesSet with the rules data
+// in the specified file. This function obtains a write lock on
+// the database, blocking all readers. When Reload fails, the previous
+// state is retained.
+func (set *TrafficRulesSet) Reload(ruleSetFilename string) error {
+	set.Lock()
+	defer set.Unlock()
+
+	if ruleSetFilename == "" {
+		// No traffic rules filename in the config
+		return nil
+	}
+
+	configJSON, err := ioutil.ReadFile(ruleSetFilename)
+	if err != nil {
+		return psiphon.ContextError(err)
+	}
+
+	var newSet TrafficRulesSet
+	err = json.Unmarshal(configJSON, &newSet)
+	if err != nil {
+		return psiphon.ContextError(err)
+	}
+
+	*set = newSet
+
+	return nil
+}
+
+// GetTrafficRules looks up the traffic rules for the specified country. If there
+// are no regional TrafficRules for the country, default TrafficRules are returned.
+func (set *TrafficRulesSet) GetTrafficRules(clientCountryCode string) TrafficRules {
+	set.RLock()
+	defer set.RUnlock()
+
+	// TODO: faster lookup?
+	for countryCodes, trafficRules := range set.RegionalRules {
+		for _, countryCode := range strings.Split(countryCodes, " ") {
+			if countryCode == clientCountryCode {
+				return trafficRules
+			}
+		}
+	}
+	return set.DefaultRules
+}
+
+// GetRateLimits looks up the rate limits for the specified tunnel protocol.
+// If there are no specific RateLimits for the protocol, default RateLimits are
+// returned.
+func (rules *TrafficRules) GetRateLimits(clientTunnelProtocol string) RateLimits {
+
+	// TODO: faster lookup?
+	for tunnelProtocols, rateLimits := range rules.ProtocolLimits {
+		for _, tunnelProtocol := range strings.Split(tunnelProtocols, " ") {
+			if tunnelProtocol == clientTunnelProtocol {
+				return rateLimits
+			}
+		}
+	}
+	return rules.DefaultLimits
+}

+ 32 - 32
psiphon/server/tunnelServer.go

@@ -31,7 +31,6 @@ import (
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
 	"golang.org/x/crypto/ssh"
 )
 
@@ -44,7 +43,6 @@ import (
 // and meek protocols, which provide further circumvention
 // capabilities.
 type TunnelServer struct {
-	config            *Config
 	runWaitGroup      *sync.WaitGroup
 	listenerError     chan error
 	shutdownBroadcast <-chan struct{}
@@ -53,18 +51,15 @@ type TunnelServer struct {
 
 // NewTunnelServer initializes a new tunnel server.
 func NewTunnelServer(
-	config *Config,
-	psinetDatabase *psinet.Database,
+	support *SupportServices,
 	shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
 
-	sshServer, err := newSSHServer(
-		config, psinetDatabase, shutdownBroadcast)
+	sshServer, err := newSSHServer(support, shutdownBroadcast)
 	if err != nil {
 		return nil, psiphon.ContextError(err)
 	}
 
 	return &TunnelServer{
-		config:            config,
 		runWaitGroup:      new(sync.WaitGroup),
 		listenerError:     make(chan error),
 		shutdownBroadcast: shutdownBroadcast,
@@ -105,15 +100,18 @@ func (server *TunnelServer) Run() error {
 		tunnelProtocol string
 	}
 
+	// TODO: should TunnelServer hold its own support pointer?
+	support := server.sshServer.support
+
 	// First bind all listeners; once all are successful,
 	// start accepting connections on each.
 
 	var listeners []*sshListener
 
-	for tunnelProtocol, listenPort := range server.config.TunnelProtocolPorts {
+	for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
 
 		localAddress := fmt.Sprintf(
-			"%s:%d", server.config.ServerIPAddress, listenPort)
+			"%s:%d", support.Config.ServerIPAddress, listenPort)
 
 		listener, err := net.Listen("tcp", localAddress)
 		if err != nil {
@@ -183,8 +181,7 @@ func (server *TunnelServer) Run() error {
 type sshClientID uint64
 
 type sshServer struct {
-	config            *Config
-	psinetDatabase    *psinet.Database
+	support           *SupportServices
 	shutdownBroadcast <-chan struct{}
 	sshHostKey        ssh.Signer
 	nextClientID      sshClientID
@@ -194,11 +191,10 @@ type sshServer struct {
 }
 
 func newSSHServer(
-	config *Config,
-	psinetDatabase *psinet.Database,
+	support *SupportServices,
 	shutdownBroadcast <-chan struct{}) (*sshServer, error) {
 
-	privateKey, err := ssh.ParseRawPrivateKey([]byte(config.SSHPrivateKey))
+	privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
 	if err != nil {
 		return nil, psiphon.ContextError(err)
 	}
@@ -210,8 +206,7 @@ func newSSHServer(
 	}
 
 	return &sshServer{
-		config:            config,
-		psinetDatabase:    psinetDatabase,
+		support:           support,
 		shutdownBroadcast: shutdownBroadcast,
 		sshHostKey:        signer,
 		nextClientID:      1,
@@ -241,7 +236,7 @@ func (sshServer *sshServer) runListener(
 		psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol) {
 
 		meekServer, err := NewMeekServer(
-			sshServer.config,
+			sshServer.support,
 			listener,
 			psiphon.TunnelProtocolUsesMeekHTTPS(tunnelProtocol),
 			handleClient,
@@ -355,13 +350,16 @@ func (sshServer *sshServer) stopClients() {
 
 func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.Conn) {
 
-	geoIPData := GeoIPLookup(psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
+	geoIPData := sshServer.support.GeoIPService.Lookup(
+		psiphon.IPAddressFromAddr(clientConn.RemoteAddr()))
+
+	// TODO: apply reload of TrafficRulesSet to existing clients
 
 	sshClient := newSshClient(
 		sshServer,
 		tunnelProtocol,
 		geoIPData,
-		sshServer.config.GetTrafficRules(geoIPData.Country))
+		sshServer.support.TrafficRulesSet.GetTrafficRules(geoIPData.Country))
 
 	// Wrap the base client connection with an ActivityMonitoredConn which will
 	// terminate the connection if no data is received before the deadline. This
@@ -411,7 +409,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 		sshServerConfig := &ssh.ServerConfig{
 			PasswordCallback: sshClient.passwordCallback,
 			AuthLogCallback:  sshClient.authLogCallback,
-			ServerVersion:    sshServer.config.SSHServerVersion,
+			ServerVersion:    sshServer.support.Config.SSHServerVersion,
 		}
 		sshServerConfig.AddHostKey(sshServer.sshHostKey)
 
@@ -425,7 +423,7 @@ func (sshServer *sshServer) handleClient(tunnelProtocol string, clientConn net.C
 			conn, result.err = psiphon.NewObfuscatedSshConn(
 				psiphon.OBFUSCATION_CONN_MODE_SERVER,
 				clientConn,
-				sshServer.config.ObfuscatedSSHKey)
+				sshServer.support.Config.ObfuscatedSSHKey)
 			if result.err != nil {
 				result.err = psiphon.ContextError(result.err)
 			}
@@ -538,15 +536,15 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 		}
 	}
 
-	if !isHexDigits(sshClient.sshServer.config, sshPasswordPayload.SessionId) {
+	if !isHexDigits(sshClient.sshServer.support, sshPasswordPayload.SessionId) {
 		return nil, psiphon.ContextError(fmt.Errorf("invalid session ID for %q", conn.User()))
 	}
 
 	userOk := (subtle.ConstantTimeCompare(
-		[]byte(conn.User()), []byte(sshClient.sshServer.config.SSHUserName)) == 1)
+		[]byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
 
 	passwordOk := (subtle.ConstantTimeCompare(
-		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.config.SSHPassword)) == 1)
+		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
 
 	if !userOk || !passwordOk {
 		return nil, psiphon.ContextError(fmt.Errorf("invalid password for %q", conn.User()))
@@ -561,20 +559,23 @@ func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []b
 
 	// Store the GeoIP data associated with the session ID. This makes the GeoIP data
 	// available to the web server for web transport Psiphon API requests.
-	SetGeoIPSessionCache(psiphonSessionID, geoIPData)
+	sshClient.sshServer.support.GeoIPService.SetSessionCache(
+		psiphonSessionID, geoIPData)
 
 	return nil, nil
 }
 
 func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
 	if err != nil {
-		if sshClient.sshServer.config.UseFail2Ban() {
+		logFields := LogFields{"error": err, "method": method}
+		if sshClient.sshServer.support.Config.UseFail2Ban() {
 			clientIPAddress := psiphon.IPAddressFromAddr(conn.RemoteAddr())
 			if clientIPAddress != "" {
-				LogFail2Ban(clientIPAddress)
+				logFields["fail2ban"] = fmt.Sprintf(
+					sshClient.sshServer.support.Config.Fail2BanFormat, clientIPAddress)
 			}
 		}
-		log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
+		log.WithContextFields(LogFields{"error": err, "method": method}).Error("authentication failed")
 	} else {
 		log.WithContextFields(LogFields{"error": err, "method": method}).Debug("authentication success")
 	}
@@ -633,8 +634,7 @@ func (sshClient *sshClient) runClient(
 
 			// requests are processed serially; responses must be sent in request order.
 			responsePayload, err := sshAPIRequestHandler(
-				sshClient.sshServer.config,
-				sshClient.sshServer.psinetDatabase,
+				sshClient.sshServer.support,
 				sshClient.geoIPData,
 				request.Type,
 				request.Payload)
@@ -697,8 +697,8 @@ func (sshClient *sshClient) handleNewPortForwardChannel(newChannel ssh.NewChanne
 
 	// Intercept TCP port forwards to a specified udpgw server and handle directly.
 	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
-	isUDPChannel := sshClient.sshServer.config.UDPInterceptUdpgwServerAddress != "" &&
-		sshClient.sshServer.config.UDPInterceptUdpgwServerAddress ==
+	isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
+		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
 			fmt.Sprintf("%s:%d",
 				directTcpipExtraData.HostToConnect,
 				directTcpipExtraData.PortToConnect)

+ 2 - 2
psiphon/server/udp.go

@@ -256,10 +256,10 @@ func (mux *udpPortForwardMultiplexer) run() {
 func (mux *udpPortForwardMultiplexer) transparentDNSAddress(
 	dialIP net.IP, dialPort int) (net.IP, int) {
 
-	if mux.sshClient.sshServer.config.UDPForwardDNSServerAddress != "" {
+	if mux.sshClient.sshServer.support.Config.UDPForwardDNSServerAddress != "" {
 		// Note: UDPForwardDNSServerAddress is validated in LoadConfig
 		host, portStr, _ := net.SplitHostPort(
-			mux.sshClient.sshServer.config.UDPForwardDNSServerAddress)
+			mux.sshClient.sshServer.support.Config.UDPForwardDNSServerAddress)
 		dialIP = net.ParseIP(host)
 		dialPort, _ = strconv.Atoi(portStr)
 	}

+ 14 - 19
psiphon/server/webServer.go

@@ -30,13 +30,11 @@ import (
 	"sync"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
-	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server/psinet"
 )
 
 type webServer struct {
-	serveMux       *http.ServeMux
-	config         *Config
-	psinetDatabase *psinet.Database
+	support  *SupportServices
+	serveMux *http.ServeMux
 }
 
 // RunWebServer runs a web server which supports tunneled and untunneled
@@ -54,13 +52,11 @@ type webServer struct {
 // compatible with older clients.
 //
 func RunWebServer(
-	config *Config,
-	psinetDatabase *psinet.Database,
+	support *SupportServices,
 	shutdownBroadcast <-chan struct{}) error {
 
 	webServer := &webServer{
-		config:         config,
-		psinetDatabase: psinetDatabase,
+		support: support,
 	}
 
 	serveMux := http.NewServeMux()
@@ -70,8 +66,8 @@ func RunWebServer(
 	serveMux.HandleFunc("/client_verification", webServer.clientVerificationHandler)
 
 	certificate, err := tls.X509KeyPair(
-		[]byte(config.WebServerCertificate),
-		[]byte(config.WebServerPrivateKey))
+		[]byte(support.Config.WebServerCertificate),
+		[]byte(support.Config.WebServerPrivateKey))
 	if err != nil {
 		return psiphon.ContextError(err)
 	}
@@ -96,7 +92,9 @@ func RunWebServer(
 	}
 
 	listener, err := net.Listen(
-		"tcp", fmt.Sprintf("%s:%d", config.ServerIPAddress, config.WebServerPort))
+		"tcp", fmt.Sprintf("%s:%d",
+			support.Config.ServerIPAddress,
+			support.Config.WebServerPort))
 	if err != nil {
 		return psiphon.ContextError(err)
 	}
@@ -188,7 +186,7 @@ func (webServer *webServer) lookupGeoIPData(params requestJSONObject) GeoIPData
 		return NewGeoIPData()
 	}
 
-	return GetGeoIPSessionCache(clientSessionID)
+	return webServer.support.GeoIPService.GetSessionCache(clientSessionID)
 }
 
 func (webServer *webServer) handshakeHandler(w http.ResponseWriter, r *http.Request) {
@@ -198,10 +196,7 @@ func (webServer *webServer) handshakeHandler(w http.ResponseWriter, r *http.Requ
 	var responsePayload []byte
 	if err == nil {
 		responsePayload, err = handshakeAPIRequestHandler(
-			webServer.config,
-			webServer.psinetDatabase,
-			webServer.lookupGeoIPData(params),
-			params)
+			webServer.support, webServer.lookupGeoIPData(params), params)
 	}
 
 	if err != nil {
@@ -227,7 +222,7 @@ func (webServer *webServer) connectedHandler(w http.ResponseWriter, r *http.Requ
 	var responsePayload []byte
 	if err == nil {
 		responsePayload, err = connectedAPIRequestHandler(
-			webServer.config, webServer.lookupGeoIPData(params), params)
+			webServer.support, webServer.lookupGeoIPData(params), params)
 	}
 
 	if err != nil {
@@ -246,7 +241,7 @@ func (webServer *webServer) statusHandler(w http.ResponseWriter, r *http.Request
 
 	if err == nil {
 		_, err = statusAPIRequestHandler(
-			webServer.config, webServer.lookupGeoIPData(params), params)
+			webServer.support, webServer.lookupGeoIPData(params), params)
 	}
 
 	if err != nil {
@@ -264,7 +259,7 @@ func (webServer *webServer) clientVerificationHandler(w http.ResponseWriter, r *
 
 	if err == nil {
 		_, err = clientVerificationAPIRequestHandler(
-			webServer.config, webServer.lookupGeoIPData(params), params)
+			webServer.support, webServer.lookupGeoIPData(params), params)
 	}
 
 	if err != nil {