treewide: Replace GPLv2 boilerplate/reference with SPDX - rule 118
[sfrench/cifs-2.6.git] / fs / squashfs / zstd_wrapper.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Squashfs - a compressed read only filesystem for Linux
4  *
5  * Copyright (c) 2016-present, Facebook, Inc.
6  * All rights reserved.
7  *
8  * zstd_wrapper.c
9  */
10
11 #include <linux/mutex.h>
12 #include <linux/buffer_head.h>
13 #include <linux/slab.h>
14 #include <linux/zstd.h>
15 #include <linux/vmalloc.h>
16
17 #include "squashfs_fs.h"
18 #include "squashfs_fs_sb.h"
19 #include "squashfs.h"
20 #include "decompressor.h"
21 #include "page_actor.h"
22
/*
 * Per-stream zstd decompression state: created by zstd_init(),
 * released by zstd_free(), and consumed by zstd_uncompress().
 */
struct workspace {
	size_t window_size;	/* max(block_size, SQUASHFS_METADATA_SIZE) */
	size_t mem_size;	/* bytes in @mem, from ZSTD_DStreamWorkspaceBound() */
	void *mem;		/* vmalloc()ed scratch memory handed to zstd */
};
28
29 static void *zstd_init(struct squashfs_sb_info *msblk, void *buff)
30 {
31         struct workspace *wksp = kmalloc(sizeof(*wksp), GFP_KERNEL);
32
33         if (wksp == NULL)
34                 goto failed;
35         wksp->window_size = max_t(size_t,
36                         msblk->block_size, SQUASHFS_METADATA_SIZE);
37         wksp->mem_size = ZSTD_DStreamWorkspaceBound(wksp->window_size);
38         wksp->mem = vmalloc(wksp->mem_size);
39         if (wksp->mem == NULL)
40                 goto failed;
41
42         return wksp;
43
44 failed:
45         ERROR("Failed to allocate zstd workspace\n");
46         kfree(wksp);
47         return ERR_PTR(-ENOMEM);
48 }
49
50
51 static void zstd_free(void *strm)
52 {
53         struct workspace *wksp = strm;
54
55         if (wksp)
56                 vfree(wksp->mem);
57         kfree(wksp);
58 }
59
60
61 static int zstd_uncompress(struct squashfs_sb_info *msblk, void *strm,
62         struct buffer_head **bh, int b, int offset, int length,
63         struct squashfs_page_actor *output)
64 {
65         struct workspace *wksp = strm;
66         ZSTD_DStream *stream;
67         size_t total_out = 0;
68         size_t zstd_err;
69         int k = 0;
70         ZSTD_inBuffer in_buf = { NULL, 0, 0 };
71         ZSTD_outBuffer out_buf = { NULL, 0, 0 };
72
73         stream = ZSTD_initDStream(wksp->window_size, wksp->mem, wksp->mem_size);
74
75         if (!stream) {
76                 ERROR("Failed to initialize zstd decompressor\n");
77                 goto out;
78         }
79
80         out_buf.size = PAGE_SIZE;
81         out_buf.dst = squashfs_first_page(output);
82
83         do {
84                 if (in_buf.pos == in_buf.size && k < b) {
85                         int avail = min(length, msblk->devblksize - offset);
86
87                         length -= avail;
88                         in_buf.src = bh[k]->b_data + offset;
89                         in_buf.size = avail;
90                         in_buf.pos = 0;
91                         offset = 0;
92                 }
93
94                 if (out_buf.pos == out_buf.size) {
95                         out_buf.dst = squashfs_next_page(output);
96                         if (out_buf.dst == NULL) {
97                                 /* Shouldn't run out of pages
98                                  * before stream is done.
99                                  */
100                                 squashfs_finish_page(output);
101                                 goto out;
102                         }
103                         out_buf.pos = 0;
104                         out_buf.size = PAGE_SIZE;
105                 }
106
107                 total_out -= out_buf.pos;
108                 zstd_err = ZSTD_decompressStream(stream, &out_buf, &in_buf);
109                 total_out += out_buf.pos; /* add the additional data produced */
110
111                 if (in_buf.pos == in_buf.size && k < b)
112                         put_bh(bh[k++]);
113         } while (zstd_err != 0 && !ZSTD_isError(zstd_err));
114
115         squashfs_finish_page(output);
116
117         if (ZSTD_isError(zstd_err)) {
118                 ERROR("zstd decompression error: %d\n",
119                                 (int)ZSTD_getErrorCode(zstd_err));
120                 goto out;
121         }
122
123         if (k < b)
124                 goto out;
125
126         return (int)total_out;
127
128 out:
129         for (; k < b; k++)
130                 put_bh(bh[k]);
131
132         return -EIO;
133 }
134
135 const struct squashfs_decompressor squashfs_zstd_comp_ops = {
136         .init = zstd_init,
137         .free = zstd_free,
138         .decompress = zstd_uncompress,
139         .id = ZSTD_COMPRESSION,
140         .name = "zstd",
141         .supported = 1
142 };