1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zstd
6
7import (
8	"io"
9	"math/bits"
10)
11
// maxHuffmanBits is the largest number of bits that may be used to
// index a Huffman table. readHuff rejects any set of weights that
// would require more table bits than this.
const maxHuffmanBits = 11
14
15// readHuff reads Huffman table from data starting at off into table.
16// Each entry in a Huffman table is a pair of bytes.
17// The high byte is the encoded value. The low byte is the number
18// of bits used to encode that value. We index into the table
19// with a value of size tableBits. A value that requires fewer bits
20// appear in the table multiple times.
21// This returns the number of bits in the Huffman table and the new offset.
22// RFC 4.2.1.
23func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
24	if off >= len(data) {
25		return 0, 0, r.makeEOFError(off)
26	}
27
28	hdr := data[off]
29	off++
30
31	var weights [256]uint8
32	var count int
33	if hdr < 128 {
34		// The table is compressed using an FSE. RFC 4.2.1.2.
35		if len(r.fseScratch) < 1<<6 {
36			r.fseScratch = make([]fseEntry, 1<<6)
37		}
38		fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
39		if err != nil {
40			return 0, 0, err
41		}
42		fseTable := r.fseScratch
43
44		if off+int(hdr) > len(data) {
45			return 0, 0, r.makeEOFError(off)
46		}
47
48		rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
49		if err != nil {
50			return 0, 0, err
51		}
52
53		state1, err := rbr.val(uint8(fseBits))
54		if err != nil {
55			return 0, 0, err
56		}
57
58		state2, err := rbr.val(uint8(fseBits))
59		if err != nil {
60			return 0, 0, err
61		}
62
63		// There are two independent FSE streams, tracked by
64		// state1 and state2. We decode them alternately.
65
66		for {
67			pt := &fseTable[state1]
68			if !rbr.fetch(pt.bits) {
69				if count >= 254 {
70					return 0, 0, rbr.makeError("Huffman count overflow")
71				}
72				weights[count] = pt.sym
73				weights[count+1] = fseTable[state2].sym
74				count += 2
75				break
76			}
77
78			v, err := rbr.val(pt.bits)
79			if err != nil {
80				return 0, 0, err
81			}
82			state1 = uint32(pt.base) + v
83
84			if count >= 255 {
85				return 0, 0, rbr.makeError("Huffman count overflow")
86			}
87
88			weights[count] = pt.sym
89			count++
90
91			pt = &fseTable[state2]
92
93			if !rbr.fetch(pt.bits) {
94				if count >= 254 {
95					return 0, 0, rbr.makeError("Huffman count overflow")
96				}
97				weights[count] = pt.sym
98				weights[count+1] = fseTable[state1].sym
99				count += 2
100				break
101			}
102
103			v, err = rbr.val(pt.bits)
104			if err != nil {
105				return 0, 0, err
106			}
107			state2 = uint32(pt.base) + v
108
109			if count >= 255 {
110				return 0, 0, rbr.makeError("Huffman count overflow")
111			}
112
113			weights[count] = pt.sym
114			count++
115		}
116
117		off += int(hdr)
118	} else {
119		// The table is not compressed. Each weight is 4 bits.
120
121		count = int(hdr) - 127
122		if off+((count+1)/2) >= len(data) {
123			return 0, 0, io.ErrUnexpectedEOF
124		}
125		for i := 0; i < count; i += 2 {
126			b := data[off]
127			off++
128			weights[i] = b >> 4
129			weights[i+1] = b & 0xf
130		}
131	}
132
133	// RFC 4.2.1.3.
134
135	var weightMark [13]uint32
136	weightMask := uint32(0)
137	for _, w := range weights[:count] {
138		if w > 12 {
139			return 0, 0, r.makeError(off, "Huffman weight overflow")
140		}
141		weightMark[w]++
142		if w > 0 {
143			weightMask += 1 << (w - 1)
144		}
145	}
146	if weightMask == 0 {
147		return 0, 0, r.makeError(off, "bad Huffman weights")
148	}
149
150	tableBits = 32 - bits.LeadingZeros32(weightMask)
151	if tableBits > maxHuffmanBits {
152		return 0, 0, r.makeError(off, "bad Huffman weights")
153	}
154
155	if len(table) < 1<<tableBits {
156		return 0, 0, r.makeError(off, "Huffman table too small")
157	}
158
159	// Work out the last weight value, which is omitted because
160	// the weights must sum to a power of two.
161	left := (uint32(1) << tableBits) - weightMask
162	if left == 0 {
163		return 0, 0, r.makeError(off, "bad Huffman weights")
164	}
165	highBit := 31 - bits.LeadingZeros32(left)
166	if uint32(1)<<highBit != left {
167		return 0, 0, r.makeError(off, "bad Huffman weights")
168	}
169	if count >= 256 {
170		return 0, 0, r.makeError(off, "Huffman weight overflow")
171	}
172	weights[count] = uint8(highBit + 1)
173	count++
174	weightMark[highBit+1]++
175
176	if weightMark[1] < 2 || weightMark[1]&1 != 0 {
177		return 0, 0, r.makeError(off, "bad Huffman weights")
178	}
179
180	// Change weightMark from a count of weights to the index of
181	// the first symbol for that weight. We shift the indexes to
182	// also store how many we have seen so far,
183	next := uint32(0)
184	for i := 0; i < tableBits; i++ {
185		cur := next
186		next += weightMark[i+1] << i
187		weightMark[i+1] = cur
188	}
189
190	for i, w := range weights[:count] {
191		if w == 0 {
192			continue
193		}
194		length := uint32(1) << (w - 1)
195		tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
196		start := weightMark[w]
197		for j := uint32(0); j < length; j++ {
198			table[start+j] = tval
199		}
200		weightMark[w] += length
201	}
202
203	return tableBits, off, nil
204}
205