1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zstd
6
7import (
8	"encoding/binary"
9)
10
11// readLiterals reads and decompresses the literals from data at off.
12// The literals are appended to outbuf, which is returned.
13// Also returns the new input offset. RFC 3.1.1.3.1.
14func (r *Reader) readLiterals(data block, off int, outbuf []byte) (int, []byte, error) {
15	if off >= len(data) {
16		return 0, nil, r.makeEOFError(off)
17	}
18
19	// Literals section header. RFC 3.1.1.3.1.1.
20	hdr := data[off]
21	off++
22
23	if (hdr&3) == 0 || (hdr&3) == 1 {
24		return r.readRawRLELiterals(data, off, hdr, outbuf)
25	} else {
26		return r.readHuffLiterals(data, off, hdr, outbuf)
27	}
28}
29
30// readRawRLELiterals reads and decompresses a Raw_Literals_Block or
31// a RLE_Literals_Block. RFC 3.1.1.3.1.1.
32func (r *Reader) readRawRLELiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
33	raw := (hdr & 3) == 0
34
35	var regeneratedSize int
36	switch (hdr >> 2) & 3 {
37	case 0, 2:
38		regeneratedSize = int(hdr >> 3)
39	case 1:
40		if off >= len(data) {
41			return 0, nil, r.makeEOFError(off)
42		}
43		regeneratedSize = int(hdr>>4) + (int(data[off]) << 4)
44		off++
45	case 3:
46		if off+1 >= len(data) {
47			return 0, nil, r.makeEOFError(off)
48		}
49		regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + (int(data[off+1]) << 12)
50		off += 2
51	}
52
53	// We are going to use the entire literal block in the output.
54	// The maximum size of one decompressed block is 128K,
55	// so we can't have more literals than that.
56	if regeneratedSize > 128<<10 {
57		return 0, nil, r.makeError(off, "literal size too large")
58	}
59
60	if raw {
61		// RFC 3.1.1.3.1.2.
62		if off+regeneratedSize > len(data) {
63			return 0, nil, r.makeError(off, "raw literal size too large")
64		}
65		outbuf = append(outbuf, data[off:off+regeneratedSize]...)
66		off += regeneratedSize
67	} else {
68		// RFC 3.1.1.3.1.3.
69		if off >= len(data) {
70			return 0, nil, r.makeError(off, "RLE literal missing")
71		}
72		rle := data[off]
73		off++
74		for i := 0; i < regeneratedSize; i++ {
75			outbuf = append(outbuf, rle)
76		}
77	}
78
79	return off, outbuf, nil
80}
81
82// readHuffLiterals reads and decompresses a Compressed_Literals_Block or
83// a Treeless_Literals_Block. RFC 3.1.1.3.1.4.
84func (r *Reader) readHuffLiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
85	var (
86		regeneratedSize int
87		compressedSize  int
88		streams         int
89	)
90	switch (hdr >> 2) & 3 {
91	case 0, 1:
92		if off+1 >= len(data) {
93			return 0, nil, r.makeEOFError(off)
94		}
95		regeneratedSize = (int(hdr) >> 4) | ((int(data[off]) & 0x3f) << 4)
96		compressedSize = (int(data[off]) >> 6) | (int(data[off+1]) << 2)
97		off += 2
98		if ((hdr >> 2) & 3) == 0 {
99			streams = 1
100		} else {
101			streams = 4
102		}
103	case 2:
104		if off+2 >= len(data) {
105			return 0, nil, r.makeEOFError(off)
106		}
107		regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 3) << 12)
108		compressedSize = (int(data[off+1]) >> 2) | (int(data[off+2]) << 6)
109		off += 3
110		streams = 4
111	case 3:
112		if off+3 >= len(data) {
113			return 0, nil, r.makeEOFError(off)
114		}
115		regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 0x3f) << 12)
116		compressedSize = (int(data[off+1]) >> 6) | (int(data[off+2]) << 2) | (int(data[off+3]) << 10)
117		off += 4
118		streams = 4
119	}
120
121	// We are going to use the entire literal block in the output.
122	// The maximum size of one decompressed block is 128K,
123	// so we can't have more literals than that.
124	if regeneratedSize > 128<<10 {
125		return 0, nil, r.makeError(off, "literal size too large")
126	}
127
128	roff := off + compressedSize
129	if roff > len(data) || roff < 0 {
130		return 0, nil, r.makeEOFError(off)
131	}
132
133	totalStreamsSize := compressedSize
134	if (hdr & 3) == 2 {
135		// Compressed_Literals_Block.
136		// Read new huffman tree.
137
138		if len(r.huffmanTable) < 1<<maxHuffmanBits {
139			r.huffmanTable = make([]uint16, 1<<maxHuffmanBits)
140		}
141
142		huffmanTableBits, hoff, err := r.readHuff(data, off, r.huffmanTable)
143		if err != nil {
144			return 0, nil, err
145		}
146		r.huffmanTableBits = huffmanTableBits
147
148		if totalStreamsSize < hoff-off {
149			return 0, nil, r.makeError(off, "Huffman table too big")
150		}
151		totalStreamsSize -= hoff - off
152		off = hoff
153	} else {
154		// Treeless_Literals_Block
155		// Reuse previous Huffman tree.
156		if r.huffmanTableBits == 0 {
157			return 0, nil, r.makeError(off, "missing literals Huffman tree")
158		}
159	}
160
161	// Decompress compressedSize bytes of data at off using the
162	// Huffman tree.
163
164	var err error
165	if streams == 1 {
166		outbuf, err = r.readLiteralsOneStream(data, off, totalStreamsSize, regeneratedSize, outbuf)
167	} else {
168		outbuf, err = r.readLiteralsFourStreams(data, off, totalStreamsSize, regeneratedSize, outbuf)
169	}
170
171	if err != nil {
172		return 0, nil, err
173	}
174
175	return roff, outbuf, nil
176}
177
178// readLiteralsOneStream reads a single stream of compressed literals.
179func (r *Reader) readLiteralsOneStream(data block, off, compressedSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
180	// We let the reverse bit reader read earlier bytes,
181	// because the Huffman table ignores bits that it doesn't need.
182	rbr, err := r.makeReverseBitReader(data, off+compressedSize-1, off-2)
183	if err != nil {
184		return nil, err
185	}
186
187	huffTable := r.huffmanTable
188	huffBits := uint32(r.huffmanTableBits)
189	huffMask := (uint32(1) << huffBits) - 1
190
191	for i := 0; i < regeneratedSize; i++ {
192		if !rbr.fetch(uint8(huffBits)) {
193			return nil, rbr.makeError("literals Huffman stream out of bits")
194		}
195
196		var t uint16
197		idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
198		t = huffTable[idx]
199		outbuf = append(outbuf, byte(t>>8))
200		rbr.cnt -= uint32(t & 0xff)
201	}
202
203	return outbuf, nil
204}
205
206// readLiteralsFourStreams reads four interleaved streams of
207// compressed literals.
208func (r *Reader) readLiteralsFourStreams(data block, off, totalStreamsSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
209	// Read the jump table to find out where the streams are.
210	// RFC 3.1.1.3.1.6.
211	if off+5 >= len(data) {
212		return nil, r.makeEOFError(off)
213	}
214	if totalStreamsSize < 6 {
215		return nil, r.makeError(off, "total streams size too small for jump table")
216	}
217	// RFC 3.1.1.3.1.6.
218	// "The decompressed size of each stream is equal to (Regenerated_Size+3)/4,
219	// except for the last stream, which may be up to 3 bytes smaller,
220	// to reach a total decompressed size as specified in Regenerated_Size."
221	regeneratedStreamSize := (regeneratedSize + 3) / 4
222	if regeneratedSize < regeneratedStreamSize*3 {
223		return nil, r.makeError(off, "regenerated size too small to decode streams")
224	}
225
226	streamSize1 := binary.LittleEndian.Uint16(data[off:])
227	streamSize2 := binary.LittleEndian.Uint16(data[off+2:])
228	streamSize3 := binary.LittleEndian.Uint16(data[off+4:])
229	off += 6
230
231	tot := uint64(streamSize1) + uint64(streamSize2) + uint64(streamSize3)
232	if tot > uint64(totalStreamsSize)-6 {
233		return nil, r.makeEOFError(off)
234	}
235	streamSize4 := uint32(totalStreamsSize) - 6 - uint32(tot)
236
237	off--
238	off1 := off + int(streamSize1)
239	start1 := off + 1
240
241	off2 := off1 + int(streamSize2)
242	start2 := off1 + 1
243
244	off3 := off2 + int(streamSize3)
245	start3 := off2 + 1
246
247	off4 := off3 + int(streamSize4)
248	start4 := off3 + 1
249
250	// We let the reverse bit readers read earlier bytes,
251	// because the Huffman tables ignore bits that they don't need.
252
253	rbr1, err := r.makeReverseBitReader(data, off1, start1-2)
254	if err != nil {
255		return nil, err
256	}
257
258	rbr2, err := r.makeReverseBitReader(data, off2, start2-2)
259	if err != nil {
260		return nil, err
261	}
262
263	rbr3, err := r.makeReverseBitReader(data, off3, start3-2)
264	if err != nil {
265		return nil, err
266	}
267
268	rbr4, err := r.makeReverseBitReader(data, off4, start4-2)
269	if err != nil {
270		return nil, err
271	}
272
273	out1 := len(outbuf)
274	out2 := out1 + regeneratedStreamSize
275	out3 := out2 + regeneratedStreamSize
276	out4 := out3 + regeneratedStreamSize
277
278	regeneratedStreamSize4 := regeneratedSize - regeneratedStreamSize*3
279
280	outbuf = append(outbuf, make([]byte, regeneratedSize)...)
281
282	huffTable := r.huffmanTable
283	huffBits := uint32(r.huffmanTableBits)
284	huffMask := (uint32(1) << huffBits) - 1
285
286	for i := 0; i < regeneratedStreamSize; i++ {
287		use4 := i < regeneratedStreamSize4
288
289		fetchHuff := func(rbr *reverseBitReader) (uint16, error) {
290			if !rbr.fetch(uint8(huffBits)) {
291				return 0, rbr.makeError("literals Huffman stream out of bits")
292			}
293			idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
294			return huffTable[idx], nil
295		}
296
297		t1, err := fetchHuff(&rbr1)
298		if err != nil {
299			return nil, err
300		}
301
302		t2, err := fetchHuff(&rbr2)
303		if err != nil {
304			return nil, err
305		}
306
307		t3, err := fetchHuff(&rbr3)
308		if err != nil {
309			return nil, err
310		}
311
312		if use4 {
313			t4, err := fetchHuff(&rbr4)
314			if err != nil {
315				return nil, err
316			}
317			outbuf[out4] = byte(t4 >> 8)
318			out4++
319			rbr4.cnt -= uint32(t4 & 0xff)
320		}
321
322		outbuf[out1] = byte(t1 >> 8)
323		out1++
324		rbr1.cnt -= uint32(t1 & 0xff)
325
326		outbuf[out2] = byte(t2 >> 8)
327		out2++
328		rbr2.cnt -= uint32(t2 & 0xff)
329
330		outbuf[out3] = byte(t3 >> 8)
331		out3++
332		rbr3.cnt -= uint32(t3 & 0xff)
333	}
334
335	return outbuf, nil
336}
337