// dict.go — zstd dictionary loading, inspection and building.
  1. package zstd
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "math"
  9. "sort"
  10. "github.com/klauspost/compress/huff0"
  11. )
// dict holds a parsed zstd dictionary: the entropy tables, initial repeat
// offsets and raw content used to prime encoders and decoders.
type dict struct {
	id uint32 // dictionary ID from the header; 0 is not allowed for loaded dicts.

	litEnc              *huff0.Scratch // literal Huffman table, reused for every block.
	llDec, ofDec, mlDec sequenceDec    // literal-length, offset and match-length FSE decoders.

	offsets [3]int // initial repeat offsets, as stored after the entropy tables.
	content []byte // raw dictionary content used as history.
}
// dictMagic is the little-endian encoding of the dictionary magic number
// 0xEC30A437 that must prefix every zstd dictionary.
const dictMagic = "\x37\xa4\x30\xec"

// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
const dictMaxLength = 1 << 31
  22. // ID returns the dictionary id or 0 if d is nil.
  23. func (d *dict) ID() uint32 {
  24. if d == nil {
  25. return 0
  26. }
  27. return d.id
  28. }
  29. // ContentSize returns the dictionary content size or 0 if d is nil.
  30. func (d *dict) ContentSize() int {
  31. if d == nil {
  32. return 0
  33. }
  34. return len(d.content)
  35. }
  36. // Content returns the dictionary content.
  37. func (d *dict) Content() []byte {
  38. if d == nil {
  39. return nil
  40. }
  41. return d.content
  42. }
  43. // Offsets returns the initial offsets.
  44. func (d *dict) Offsets() [3]int {
  45. if d == nil {
  46. return [3]int{}
  47. }
  48. return d.offsets
  49. }
  50. // LitEncoder returns the literal encoder.
  51. func (d *dict) LitEncoder() *huff0.Scratch {
  52. if d == nil {
  53. return nil
  54. }
  55. return d.litEnc
  56. }
  57. // Load a dictionary as described in
  58. // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md#dictionary-format
  59. func loadDict(b []byte) (*dict, error) {
  60. // Check static field size.
  61. if len(b) <= 8+(3*4) {
  62. return nil, io.ErrUnexpectedEOF
  63. }
  64. d := dict{
  65. llDec: sequenceDec{fse: &fseDecoder{}},
  66. ofDec: sequenceDec{fse: &fseDecoder{}},
  67. mlDec: sequenceDec{fse: &fseDecoder{}},
  68. }
  69. if string(b[:4]) != dictMagic {
  70. return nil, ErrMagicMismatch
  71. }
  72. d.id = binary.LittleEndian.Uint32(b[4:8])
  73. if d.id == 0 {
  74. return nil, errors.New("dictionaries cannot have ID 0")
  75. }
  76. // Read literal table
  77. var err error
  78. d.litEnc, b, err = huff0.ReadTable(b[8:], nil)
  79. if err != nil {
  80. return nil, fmt.Errorf("loading literal table: %w", err)
  81. }
  82. d.litEnc.Reuse = huff0.ReusePolicyMust
  83. br := byteReader{
  84. b: b,
  85. off: 0,
  86. }
  87. readDec := func(i tableIndex, dec *fseDecoder) error {
  88. if err := dec.readNCount(&br, uint16(maxTableSymbol[i])); err != nil {
  89. return err
  90. }
  91. if br.overread() {
  92. return io.ErrUnexpectedEOF
  93. }
  94. err = dec.transform(symbolTableX[i])
  95. if err != nil {
  96. println("Transform table error:", err)
  97. return err
  98. }
  99. if debugDecoder || debugEncoder {
  100. println("Read table ok", "symbolLen:", dec.symbolLen)
  101. }
  102. // Set decoders as predefined so they aren't reused.
  103. dec.preDefined = true
  104. return nil
  105. }
  106. if err := readDec(tableOffsets, d.ofDec.fse); err != nil {
  107. return nil, err
  108. }
  109. if err := readDec(tableMatchLengths, d.mlDec.fse); err != nil {
  110. return nil, err
  111. }
  112. if err := readDec(tableLiteralLengths, d.llDec.fse); err != nil {
  113. return nil, err
  114. }
  115. if br.remain() < 12 {
  116. return nil, io.ErrUnexpectedEOF
  117. }
  118. d.offsets[0] = int(br.Uint32())
  119. br.advance(4)
  120. d.offsets[1] = int(br.Uint32())
  121. br.advance(4)
  122. d.offsets[2] = int(br.Uint32())
  123. br.advance(4)
  124. if d.offsets[0] <= 0 || d.offsets[1] <= 0 || d.offsets[2] <= 0 {
  125. return nil, errors.New("invalid offset in dictionary")
  126. }
  127. d.content = make([]byte, br.remain())
  128. copy(d.content, br.unread())
  129. if d.offsets[0] > len(d.content) || d.offsets[1] > len(d.content) || d.offsets[2] > len(d.content) {
  130. return nil, fmt.Errorf("initial offset bigger than dictionary content size %d, offsets: %v", len(d.content), d.offsets)
  131. }
  132. return &d, nil
  133. }
  134. // InspectDictionary loads a zstd dictionary and provides functions to inspect the content.
  135. func InspectDictionary(b []byte) (interface {
  136. ID() uint32
  137. ContentSize() int
  138. Content() []byte
  139. Offsets() [3]int
  140. LitEncoder() *huff0.Scratch
  141. }, error) {
  142. initPredefined()
  143. d, err := loadDict(b)
  144. return d, err
  145. }
// BuildDictOptions contains the parameters for BuildDict.
type BuildDictOptions struct {
	// Dictionary ID.
	ID uint32

	// Content to use to create dictionary tables.
	Contents [][]byte

	// History to use for all blocks.
	History []byte

	// Offsets to use.
	Offsets [3]int

	// CompatV155 will make the dictionary compatible with Zstd v1.5.5 and earlier.
	// See https://github.com/facebook/zstd/issues/3724
	CompatV155 bool

	// Use the specified encoder level.
	// The dictionary will be built using the specified encoder level,
	// which will reflect speed and make the dictionary tailored for that level.
	// If not set SpeedBestCompression will be used.
	Level EncoderLevel

	// DebugOut will write stats and other details here if set.
	DebugOut io.Writer
}
  166. func BuildDict(o BuildDictOptions) ([]byte, error) {
  167. initPredefined()
  168. hist := o.History
  169. contents := o.Contents
  170. debug := o.DebugOut != nil
  171. println := func(args ...interface{}) {
  172. if o.DebugOut != nil {
  173. fmt.Fprintln(o.DebugOut, args...)
  174. }
  175. }
  176. printf := func(s string, args ...interface{}) {
  177. if o.DebugOut != nil {
  178. fmt.Fprintf(o.DebugOut, s, args...)
  179. }
  180. }
  181. print := func(args ...interface{}) {
  182. if o.DebugOut != nil {
  183. fmt.Fprint(o.DebugOut, args...)
  184. }
  185. }
  186. if int64(len(hist)) > dictMaxLength {
  187. return nil, fmt.Errorf("dictionary of size %d > %d", len(hist), int64(dictMaxLength))
  188. }
  189. if len(hist) < 8 {
  190. return nil, fmt.Errorf("dictionary of size %d < %d", len(hist), 8)
  191. }
  192. if len(contents) == 0 {
  193. return nil, errors.New("no content provided")
  194. }
  195. d := dict{
  196. id: o.ID,
  197. litEnc: nil,
  198. llDec: sequenceDec{},
  199. ofDec: sequenceDec{},
  200. mlDec: sequenceDec{},
  201. offsets: o.Offsets,
  202. content: hist,
  203. }
  204. block := blockEnc{lowMem: false}
  205. block.init()
  206. enc := encoder(&bestFastEncoder{fastBase: fastBase{maxMatchOff: int32(maxMatchLen), bufferReset: math.MaxInt32 - int32(maxMatchLen*2), lowMem: false}})
  207. if o.Level != 0 {
  208. eOpts := encoderOptions{
  209. level: o.Level,
  210. blockSize: maxMatchLen,
  211. windowSize: maxMatchLen,
  212. dict: &d,
  213. lowMem: false,
  214. }
  215. enc = eOpts.encoder()
  216. } else {
  217. o.Level = SpeedBestCompression
  218. }
  219. var (
  220. remain [256]int
  221. ll [256]int
  222. ml [256]int
  223. of [256]int
  224. )
  225. addValues := func(dst *[256]int, src []byte) {
  226. for _, v := range src {
  227. dst[v]++
  228. }
  229. }
  230. addHist := func(dst *[256]int, src *[256]uint32) {
  231. for i, v := range src {
  232. dst[i] += int(v)
  233. }
  234. }
  235. seqs := 0
  236. nUsed := 0
  237. litTotal := 0
  238. newOffsets := make(map[uint32]int, 1000)
  239. for _, b := range contents {
  240. block.reset(nil)
  241. if len(b) < 8 {
  242. continue
  243. }
  244. nUsed++
  245. enc.Reset(&d, true)
  246. enc.Encode(&block, b)
  247. addValues(&remain, block.literals)
  248. litTotal += len(block.literals)
  249. if len(block.sequences) == 0 {
  250. continue
  251. }
  252. seqs += len(block.sequences)
  253. block.genCodes()
  254. addHist(&ll, block.coders.llEnc.Histogram())
  255. addHist(&ml, block.coders.mlEnc.Histogram())
  256. addHist(&of, block.coders.ofEnc.Histogram())
  257. for i, seq := range block.sequences {
  258. if i > 3 {
  259. break
  260. }
  261. offset := seq.offset
  262. if offset == 0 {
  263. continue
  264. }
  265. if int(offset) >= len(o.History) {
  266. continue
  267. }
  268. if offset > 3 {
  269. newOffsets[offset-3]++
  270. } else {
  271. newOffsets[uint32(o.Offsets[offset-1])]++
  272. }
  273. }
  274. }
  275. // Find most used offsets.
  276. var sortedOffsets []uint32
  277. for k := range newOffsets {
  278. sortedOffsets = append(sortedOffsets, k)
  279. }
  280. sort.Slice(sortedOffsets, func(i, j int) bool {
  281. a, b := sortedOffsets[i], sortedOffsets[j]
  282. if a == b {
  283. // Prefer the longer offset
  284. return sortedOffsets[i] > sortedOffsets[j]
  285. }
  286. return newOffsets[sortedOffsets[i]] > newOffsets[sortedOffsets[j]]
  287. })
  288. if len(sortedOffsets) > 3 {
  289. if debug {
  290. print("Offsets:")
  291. for i, v := range sortedOffsets {
  292. if i > 20 {
  293. break
  294. }
  295. printf("[%d: %d],", v, newOffsets[v])
  296. }
  297. println("")
  298. }
  299. sortedOffsets = sortedOffsets[:3]
  300. }
  301. for i, v := range sortedOffsets {
  302. o.Offsets[i] = int(v)
  303. }
  304. if debug {
  305. println("New repeat offsets", o.Offsets)
  306. }
  307. if nUsed == 0 || seqs == 0 {
  308. return nil, fmt.Errorf("%d blocks, %d sequences found", nUsed, seqs)
  309. }
  310. if debug {
  311. println("Sequences:", seqs, "Blocks:", nUsed, "Literals:", litTotal)
  312. }
  313. if seqs/nUsed < 512 {
  314. // Use 512 as minimum.
  315. nUsed = seqs / 512
  316. if nUsed == 0 {
  317. nUsed = 1
  318. }
  319. }
  320. copyHist := func(dst *fseEncoder, src *[256]int) ([]byte, error) {
  321. hist := dst.Histogram()
  322. var maxSym uint8
  323. var maxCount int
  324. var fakeLength int
  325. for i, v := range src {
  326. if v > 0 {
  327. v = v / nUsed
  328. if v == 0 {
  329. v = 1
  330. }
  331. }
  332. if v > maxCount {
  333. maxCount = v
  334. }
  335. if v != 0 {
  336. maxSym = uint8(i)
  337. }
  338. fakeLength += v
  339. hist[i] = uint32(v)
  340. }
  341. // Ensure we aren't trying to represent RLE.
  342. if maxCount == fakeLength {
  343. for i := range hist {
  344. if uint8(i) == maxSym {
  345. fakeLength++
  346. maxSym++
  347. hist[i+1] = 1
  348. if maxSym > 1 {
  349. break
  350. }
  351. }
  352. if hist[0] == 0 {
  353. fakeLength++
  354. hist[i] = 1
  355. if maxSym > 1 {
  356. break
  357. }
  358. }
  359. }
  360. }
  361. dst.HistogramFinished(maxSym, maxCount)
  362. dst.reUsed = false
  363. dst.useRLE = false
  364. err := dst.normalizeCount(fakeLength)
  365. if err != nil {
  366. return nil, err
  367. }
  368. if debug {
  369. println("RAW:", dst.count[:maxSym+1], "NORM:", dst.norm[:maxSym+1], "LEN:", fakeLength)
  370. }
  371. return dst.writeCount(nil)
  372. }
  373. if debug {
  374. print("Literal lengths: ")
  375. }
  376. llTable, err := copyHist(block.coders.llEnc, &ll)
  377. if err != nil {
  378. return nil, err
  379. }
  380. if debug {
  381. print("Match lengths: ")
  382. }
  383. mlTable, err := copyHist(block.coders.mlEnc, &ml)
  384. if err != nil {
  385. return nil, err
  386. }
  387. if debug {
  388. print("Offsets: ")
  389. }
  390. ofTable, err := copyHist(block.coders.ofEnc, &of)
  391. if err != nil {
  392. return nil, err
  393. }
  394. // Literal table
  395. avgSize := litTotal
  396. if avgSize > huff0.BlockSizeMax/2 {
  397. avgSize = huff0.BlockSizeMax / 2
  398. }
  399. huffBuff := make([]byte, 0, avgSize)
  400. // Target size
  401. div := litTotal / avgSize
  402. if div < 1 {
  403. div = 1
  404. }
  405. if debug {
  406. println("Huffman weights:")
  407. }
  408. for i, n := range remain[:] {
  409. if n > 0 {
  410. n = n / div
  411. // Allow all entries to be represented.
  412. if n == 0 {
  413. n = 1
  414. }
  415. huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
  416. if debug {
  417. printf("[%d: %d], ", i, n)
  418. }
  419. }
  420. }
  421. if o.CompatV155 && remain[255]/div == 0 {
  422. huffBuff = append(huffBuff, 255)
  423. }
  424. scratch := &huff0.Scratch{TableLog: 11}
  425. for tries := 0; tries < 255; tries++ {
  426. scratch = &huff0.Scratch{TableLog: 11}
  427. _, _, err = huff0.Compress1X(huffBuff, scratch)
  428. if err == nil {
  429. break
  430. }
  431. if debug {
  432. printf("Try %d: Huffman error: %v\n", tries+1, err)
  433. }
  434. huffBuff = huffBuff[:0]
  435. if tries == 250 {
  436. if debug {
  437. println("Huffman: Bailing out with predefined table")
  438. }
  439. // Bail out.... Just generate something
  440. huffBuff = append(huffBuff, bytes.Repeat([]byte{255}, 10000)...)
  441. for i := 0; i < 128; i++ {
  442. huffBuff = append(huffBuff, byte(i))
  443. }
  444. continue
  445. }
  446. if errors.Is(err, huff0.ErrIncompressible) {
  447. // Try truncating least common.
  448. for i, n := range remain[:] {
  449. if n > 0 {
  450. n = n / (div * (i + 1))
  451. if n > 0 {
  452. huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
  453. }
  454. }
  455. }
  456. if o.CompatV155 && len(huffBuff) > 0 && huffBuff[len(huffBuff)-1] != 255 {
  457. huffBuff = append(huffBuff, 255)
  458. }
  459. if len(huffBuff) == 0 {
  460. huffBuff = append(huffBuff, 0, 255)
  461. }
  462. }
  463. if errors.Is(err, huff0.ErrUseRLE) {
  464. for i, n := range remain[:] {
  465. n = n / (div * (i + 1))
  466. // Allow all entries to be represented.
  467. if n == 0 {
  468. n = 1
  469. }
  470. huffBuff = append(huffBuff, bytes.Repeat([]byte{byte(i)}, n)...)
  471. }
  472. }
  473. }
  474. var out bytes.Buffer
  475. out.Write([]byte(dictMagic))
  476. out.Write(binary.LittleEndian.AppendUint32(nil, o.ID))
  477. out.Write(scratch.OutTable)
  478. if debug {
  479. println("huff table:", len(scratch.OutTable), "bytes")
  480. println("of table:", len(ofTable), "bytes")
  481. println("ml table:", len(mlTable), "bytes")
  482. println("ll table:", len(llTable), "bytes")
  483. }
  484. out.Write(ofTable)
  485. out.Write(mlTable)
  486. out.Write(llTable)
  487. out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[0])))
  488. out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[1])))
  489. out.Write(binary.LittleEndian.AppendUint32(nil, uint32(o.Offsets[2])))
  490. out.Write(hist)
  491. if debug {
  492. _, err := loadDict(out.Bytes())
  493. if err != nil {
  494. panic(err)
  495. }
  496. i, err := InspectDictionary(out.Bytes())
  497. if err != nil {
  498. panic(err)
  499. }
  500. println("ID:", i.ID())
  501. println("Content size:", i.ContentSize())
  502. println("Encoder:", i.LitEncoder() != nil)
  503. println("Offsets:", i.Offsets())
  504. var totalSize int
  505. for _, b := range contents {
  506. totalSize += len(b)
  507. }
  508. encWith := func(opts ...EOption) int {
  509. enc, err := NewWriter(nil, opts...)
  510. if err != nil {
  511. panic(err)
  512. }
  513. defer enc.Close()
  514. var dst []byte
  515. var totalSize int
  516. for _, b := range contents {
  517. dst = enc.EncodeAll(b, dst[:0])
  518. totalSize += len(dst)
  519. }
  520. return totalSize
  521. }
  522. plain := encWith(WithEncoderLevel(o.Level))
  523. withDict := encWith(WithEncoderLevel(o.Level), WithEncoderDict(out.Bytes()))
  524. println("Input size:", totalSize)
  525. println("Plain Compressed:", plain)
  526. println("Dict Compressed:", withDict)
  527. println("Saved:", plain-withDict, (plain-withDict)/len(contents), "bytes per input (rounded down)")
  528. }
  529. return out.Bytes(), nil
  530. }