utils.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package common
  20. import (
  21. "bytes"
  22. "compress/zlib"
  23. "crypto/rand"
  24. "encoding/base64"
  25. "encoding/hex"
  26. "errors"
  27. "fmt"
  28. "io/ioutil"
  29. "math/big"
  30. "runtime"
  31. "strings"
  32. "time"
  33. )
  34. // Contains is a helper function that returns true
  35. // if the target string is in the list.
  36. func Contains(list []string, target string) bool {
  37. for _, listItem := range list {
  38. if listItem == target {
  39. return true
  40. }
  41. }
  42. return false
  43. }
  44. // FlipCoin is a helper function that randomly
  45. // returns true or false. If the underlying random
  46. // number generator fails, FlipCoin still returns
  47. // a result.
  48. func FlipCoin() bool {
  49. randomInt, _ := MakeSecureRandomInt(2)
  50. return randomInt == 1
  51. }
  52. // MakeSecureRandomInt is a helper function that wraps
  53. // MakeSecureRandomInt64.
  54. func MakeSecureRandomInt(max int) (int, error) {
  55. randomInt, err := MakeSecureRandomInt64(int64(max))
  56. return int(randomInt), err
  57. }
  58. // MakeSecureRandomInt64 is a helper function that wraps
  59. // crypto/rand.Int, which returns a uniform random value in [0, max).
  60. func MakeSecureRandomInt64(max int64) (int64, error) {
  61. randomInt, err := rand.Int(rand.Reader, big.NewInt(max))
  62. if err != nil {
  63. return 0, ContextError(err)
  64. }
  65. return randomInt.Int64(), nil
  66. }
  67. // MakeSecureRandomBytes is a helper function that wraps
  68. // crypto/rand.Read.
  69. func MakeSecureRandomBytes(length int) ([]byte, error) {
  70. randomBytes := make([]byte, length)
  71. n, err := rand.Read(randomBytes)
  72. if err != nil {
  73. return nil, ContextError(err)
  74. }
  75. if n != length {
  76. return nil, ContextError(errors.New("insufficient random bytes"))
  77. }
  78. return randomBytes, nil
  79. }
  80. // MakeSecureRandomPadding selects a random padding length in the indicated
  81. // range and returns a random byte array of the selected length.
  82. // In the unlikely case where an underlying MakeRandom functions fails,
  83. // the padding is length 0.
  84. func MakeSecureRandomPadding(minLength, maxLength int) ([]byte, error) {
  85. var padding []byte
  86. paddingSize, err := MakeSecureRandomInt(maxLength - minLength)
  87. if err != nil {
  88. return nil, ContextError(err)
  89. }
  90. paddingSize += minLength
  91. padding, err = MakeSecureRandomBytes(paddingSize)
  92. if err != nil {
  93. return nil, ContextError(err)
  94. }
  95. return padding, nil
  96. }
  97. // MakeRandomPeriod returns a random duration, within a given range.
  98. // In the unlikely case where an underlying MakeRandom functions fails,
  99. // the period is the minimum.
  100. func MakeRandomPeriod(min, max time.Duration) (time.Duration, error) {
  101. period, err := MakeSecureRandomInt64(max.Nanoseconds() - min.Nanoseconds())
  102. if err != nil {
  103. return 0, ContextError(err)
  104. }
  105. return min + time.Duration(period), nil
  106. }
  107. // MakeRandomStringHex returns a hex encoded random string.
  108. // byteLength specifies the pre-encoded data length.
  109. func MakeRandomStringHex(byteLength int) (string, error) {
  110. bytes, err := MakeSecureRandomBytes(byteLength)
  111. if err != nil {
  112. return "", ContextError(err)
  113. }
  114. return hex.EncodeToString(bytes), nil
  115. }
  116. // MakeRandomStringBase64 returns a base64 encoded random string.
  117. // byteLength specifies the pre-encoded data length.
  118. func MakeRandomStringBase64(byteLength int) (string, error) {
  119. bytes, err := MakeSecureRandomBytes(byteLength)
  120. if err != nil {
  121. return "", ContextError(err)
  122. }
  123. return base64.RawURLEncoding.EncodeToString(bytes), nil
  124. }
  125. // GetCurrentTimestamp returns the current time in UTC as
  126. // an RFC 3339 formatted string.
  127. func GetCurrentTimestamp() string {
  128. return time.Now().UTC().Format(time.RFC3339)
  129. }
  130. // TruncateTimestampToHour truncates an RFC 3339 formatted string
  131. // to hour granularity. If the input is not a valid format, the
  132. // result is "".
  133. func TruncateTimestampToHour(timestamp string) string {
  134. t, err := time.Parse(time.RFC3339, timestamp)
  135. if err != nil {
  136. return ""
  137. }
  138. return t.Truncate(1 * time.Hour).Format(time.RFC3339)
  139. }
  140. // getFunctionName is a helper that extracts a simple function name from
  141. // full name returned byruntime.Func.Name(). This is used to declutter
  142. // log messages containing function names.
  143. func getFunctionName(pc uintptr) string {
  144. funcName := runtime.FuncForPC(pc).Name()
  145. index := strings.LastIndex(funcName, "/")
  146. if index != -1 {
  147. funcName = funcName[index+1:]
  148. }
  149. return funcName
  150. }
  151. // GetParentContext returns the parent function name and source file
  152. // line number.
  153. func GetParentContext() string {
  154. pc, _, line, _ := runtime.Caller(2)
  155. return fmt.Sprintf("%s#%d", getFunctionName(pc), line)
  156. }
  157. // ContextError prefixes an error message with the current function
  158. // name and source file line number.
  159. func ContextError(err error) error {
  160. if err == nil {
  161. return nil
  162. }
  163. pc, _, line, _ := runtime.Caller(1)
  164. return fmt.Errorf("%s#%d: %s", getFunctionName(pc), line, err)
  165. }
  166. // Compress returns zlib compressed data
  167. func Compress(data []byte) []byte {
  168. var compressedData bytes.Buffer
  169. writer := zlib.NewWriter(&compressedData)
  170. writer.Write(data)
  171. writer.Close()
  172. return compressedData.Bytes()
  173. }
  174. // Decompress returns zlib decompressed data
  175. func Decompress(data []byte) ([]byte, error) {
  176. reader, err := zlib.NewReader(bytes.NewReader(data))
  177. if err != nil {
  178. return nil, ContextError(err)
  179. }
  180. uncompressedData, err := ioutil.ReadAll(reader)
  181. reader.Close()
  182. if err != nil {
  183. return nil, ContextError(err)
  184. }
  185. return uncompressedData, nil
  186. }