weightedrand.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
// Package weightedrand contains a performant data structure and algorithm used
// to randomly select an element from some kind of list, where the chance of
// each element being selected is not equal, but is instead defined by relative
// "weights" (or probabilities). This is called weighted random selection.
  5. //
  6. // Compare this package with (github.com/jmcvetta/randutil).WeightedChoice,
  7. // which is optimized for the single operation case. In contrast, this package
  8. // creates a presorted cache optimized for binary search, allowing for repeated
  9. // selections from the same set to be significantly faster, especially for large
  10. // data sets.
  11. package weightedrand
  12. import (
  13. "errors"
  14. "math/rand"
  15. "sort"
  16. )
  17. // Choice is a generic wrapper that can be used to add weights for any item.
  18. type Choice struct {
  19. Item interface{}
  20. Weight uint
  21. }
  22. // NewChoice creates a new Choice with specified item and weight.
  23. func NewChoice(item interface{}, weight uint) Choice {
  24. return Choice{Item: item, Weight: weight}
  25. }
  26. // A Chooser caches many possible Choices in a structure designed to improve
  27. // performance on repeated calls for weighted random selection.
  28. type Chooser struct {
  29. data []Choice
  30. totals []int
  31. max int
  32. }
  33. // NewChooser initializes a new Chooser for picking from the provided choices.
  34. func NewChooser(choices ...Choice) (*Chooser, error) {
  35. sort.Slice(choices, func(i, j int) bool {
  36. return choices[i].Weight < choices[j].Weight
  37. })
  38. totals := make([]int, len(choices))
  39. runningTotal := 0
  40. for i, c := range choices {
  41. weight := int(c.Weight)
  42. if (maxInt - runningTotal) <= weight {
  43. return nil, errWeightOverflow
  44. }
  45. runningTotal += weight
  46. totals[i] = runningTotal
  47. }
  48. if runningTotal < 1 {
  49. return nil, errNoValidChoices
  50. }
  51. return &Chooser{data: choices, totals: totals, max: runningTotal}, nil
  52. }
  53. const (
  54. intSize = 32 << (^uint(0) >> 63) // cf. strconv.IntSize
  55. maxInt = 1<<(intSize-1) - 1
  56. )
  57. // Possible errors returned by NewChooser, preventing the creation of a Chooser
  58. // with unsafe runtime states.
  59. var (
  60. // If the sum of provided Choice weights exceed the maximum integer value
  61. // for the current platform (e.g. math.MaxInt32 or math.MaxInt64), then
  62. // the internal running total will overflow, resulting in an imbalanced
  63. // distribution generating improper results.
  64. errWeightOverflow = errors.New("sum of Choice Weights exceeds max int")
  65. // If there are no Choices available to the Chooser with a weight >= 1,
  66. // there are no valid choices and Pick would produce a runtime panic.
  67. errNoValidChoices = errors.New("zero Choices with Weight >= 1")
  68. )
  69. // Pick returns a single weighted random Choice.Item from the Chooser.
  70. //
  71. // Utilizes global rand as the source of randomness.
  72. func (c Chooser) Pick() interface{} {
  73. r := rand.Intn(c.max) + 1
  74. i := searchInts(c.totals, r)
  75. return c.data[i].Item
  76. }
  77. // PickSource returns a single weighted random Choice.Item from the Chooser,
  78. // utilizing the provided *rand.Rand source rs for randomness.
  79. //
  80. // The primary use-case for this is avoid lock contention from the global random
  81. // source if utilizing Chooser(s) from multiple goroutines in extremely
  82. // high-throughput situations.
  83. //
  84. // It is the responsibility of the caller to ensure the provided rand.Source is
  85. // free from thread safety issues.
  86. func (c Chooser) PickSource(rs *rand.Rand) interface{} {
  87. r := rs.Intn(c.max) + 1
  88. i := searchInts(c.totals, r)
  89. return c.data[i].Item
  90. }
  91. // The standard library sort.SearchInts() just wraps the generic sort.Search()
  92. // function, which takes a function closure to determine truthfulness. However,
  93. // since this function is utilized within a for loop, it cannot currently be
  94. // properly inlined by the compiler, resulting in non-trivial performance
  95. // overhead.
  96. //
  97. // Thus, this is essentially manually inlined version. In our use case here, it
  98. // results in a up to ~33% overall throughput increase for Pick().
  99. func searchInts(a []int, x int) int {
  100. // Possible further future optimization for searchInts via SIMD if we want
  101. // to write some Go assembly code: http://0x80.pl/articles/simd-search.html
  102. i, j := 0, len(a)
  103. for i < j {
  104. h := int(uint(i+j) >> 1) // avoid overflow when computing h
  105. if a[h] < x {
  106. i = h + 1
  107. } else {
  108. j = h
  109. }
  110. }
  111. return i
  112. }