Просмотр исходного кода

Fix: race conditions due to concurrent net.Conn Read/Write calls

Rod Hynes 11 лет назад
Родитель
Сommit
86511863d1
1 измененных файлов с 25 добавлено и 13 удалено
  1. 25 13
      psiphon/stats_conn.go

+ 25 - 13
psiphon/stats_conn.go

@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2014, Psiphon Inc.
+ * Copyright (c) 2015, Psiphon Inc.
  * All rights reserved.
  *
  * This program is free software: you can redistribute it and/or modify
@@ -30,16 +30,20 @@ Assumption: Enough of the first HTTP will be present in the first Write() call
 		- If this turns out to not be generally true we will need to add buffering.
 */
 
-import "net"
+import (
+	"net"
+	"sync/atomic"
+)
 
 // StatsConn is to be used as an intermediate link in a chain of net.Conn objects.
 // It inspects requests and responses and derives stats from them.
 type StatsConn struct {
 	net.Conn
-	serverID   string
-	hostname   string
-	firstWrite bool
-	regexps    *Regexps
+	serverID       string
+	firstWrite     int32
+	hostnameParsed int32
+	hostname       string
+	regexps        *Regexps
 }
 
 // NewStatsConn creates a StatsConn. serverID can be anything that uniquely
@@ -47,10 +51,11 @@ type StatsConn struct {
 // the accumulated stats.
 func NewStatsConn(nextConn net.Conn, serverID string, regexps *Regexps) *StatsConn {
 	return &StatsConn{
-		Conn:       nextConn,
-		serverID:   serverID,
-		firstWrite: true,
-		regexps:    regexps,
+		Conn:           nextConn,
+		serverID:       serverID,
+		firstWrite:     1,
+		hostnameParsed: 0,
+		regexps:        regexps,
 	}
 }
 
@@ -65,14 +70,14 @@ func (conn *StatsConn) Write(buffer []byte) (n int, err error) {
 	if n > 0 {
 		// If this is the first request, try to determine the hostname to associate
 		// with this connection.
-		if conn.firstWrite {
-			conn.firstWrite = false
+		if atomic.CompareAndSwapInt32(&conn.firstWrite, 0, 1) {
 
 			hostname, ok := getHostname(buffer)
 			if ok {
 				// Get the hostname value that will be stored in stats by
 				// regexing the real hostname.
 				conn.hostname = regexHostname(hostname, conn.regexps)
+				atomic.StoreInt32(&conn.hostnameParsed, 1)
 			}
 		}
 
@@ -90,11 +95,18 @@ func (conn *StatsConn) Write(buffer []byte) (n int, err error) {
 func (conn *StatsConn) Read(buffer []byte) (n int, err error) {
 	n, err = conn.Conn.Read(buffer)
 
+	var hostname string
+	if 1 == atomic.LoadInt32(&conn.hostnameParsed) {
+		hostname = conn.hostname
+	} else {
+		hostname = ""
+	}
+
 	// Count bytes without checking the error condition. It could happen that the
 	// buffer was partially read and then an error occurred.
 	recordStat(&statsUpdate{
 		conn.serverID,
-		conn.hostname,
+		hostname,
 		0,
 		int64(n)})