Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/tytso/random
[sfrench/cifs-2.6.git] / fs / squashfs / zstd_wrapper.c
1 /*
2  * Squashfs - a compressed read only filesystem for Linux
3  *
4  * Copyright (c) 2016-present, Facebook, Inc.
5  * All rights reserved.
6  *
7  * This program is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License
9  * as published by the Free Software Foundation; either version 2,
10  * or (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * zstd_wrapper.c
18  */
19
20 #include <linux/mutex.h>
21 #include <linux/buffer_head.h>
22 #include <linux/slab.h>
23 #include <linux/zstd.h>
24 #include <linux/vmalloc.h>
25
26 #include "squashfs_fs.h"
27 #include "squashfs_fs_sb.h"
28 #include "squashfs.h"
29 #include "decompressor.h"
30 #include "page_actor.h"
31
/*
 * Decompressor state built by zstd_init() and handed back to
 * zstd_uncompress()/zstd_free() as the opaque @strm pointer.
 */
struct workspace {
	size_t window_size;	/* window the zstd stream is initialised with */
	size_t mem_size;	/* size in bytes of @mem */
	void *mem;		/* scratch memory passed to ZSTD_initDStream() */
};
37
38 static void *zstd_init(struct squashfs_sb_info *msblk, void *buff)
39 {
40         struct workspace *wksp = kmalloc(sizeof(*wksp), GFP_KERNEL);
41
42         if (wksp == NULL)
43                 goto failed;
44         wksp->window_size = max_t(size_t,
45                         msblk->block_size, SQUASHFS_METADATA_SIZE);
46         wksp->mem_size = ZSTD_DStreamWorkspaceBound(wksp->window_size);
47         wksp->mem = vmalloc(wksp->mem_size);
48         if (wksp->mem == NULL)
49                 goto failed;
50
51         return wksp;
52
53 failed:
54         ERROR("Failed to allocate zstd workspace\n");
55         kfree(wksp);
56         return ERR_PTR(-ENOMEM);
57 }
58
59
60 static void zstd_free(void *strm)
61 {
62         struct workspace *wksp = strm;
63
64         if (wksp)
65                 vfree(wksp->mem);
66         kfree(wksp);
67 }
68
69
70 static int zstd_uncompress(struct squashfs_sb_info *msblk, void *strm,
71         struct buffer_head **bh, int b, int offset, int length,
72         struct squashfs_page_actor *output)
73 {
74         struct workspace *wksp = strm;
75         ZSTD_DStream *stream;
76         size_t total_out = 0;
77         size_t zstd_err;
78         int k = 0;
79         ZSTD_inBuffer in_buf = { NULL, 0, 0 };
80         ZSTD_outBuffer out_buf = { NULL, 0, 0 };
81
82         stream = ZSTD_initDStream(wksp->window_size, wksp->mem, wksp->mem_size);
83
84         if (!stream) {
85                 ERROR("Failed to initialize zstd decompressor\n");
86                 goto out;
87         }
88
89         out_buf.size = PAGE_SIZE;
90         out_buf.dst = squashfs_first_page(output);
91
92         do {
93                 if (in_buf.pos == in_buf.size && k < b) {
94                         int avail = min(length, msblk->devblksize - offset);
95
96                         length -= avail;
97                         in_buf.src = bh[k]->b_data + offset;
98                         in_buf.size = avail;
99                         in_buf.pos = 0;
100                         offset = 0;
101                 }
102
103                 if (out_buf.pos == out_buf.size) {
104                         out_buf.dst = squashfs_next_page(output);
105                         if (out_buf.dst == NULL) {
106                                 /* Shouldn't run out of pages
107                                  * before stream is done.
108                                  */
109                                 squashfs_finish_page(output);
110                                 goto out;
111                         }
112                         out_buf.pos = 0;
113                         out_buf.size = PAGE_SIZE;
114                 }
115
116                 total_out -= out_buf.pos;
117                 zstd_err = ZSTD_decompressStream(stream, &out_buf, &in_buf);
118                 total_out += out_buf.pos; /* add the additional data produced */
119
120                 if (in_buf.pos == in_buf.size && k < b)
121                         put_bh(bh[k++]);
122         } while (zstd_err != 0 && !ZSTD_isError(zstd_err));
123
124         squashfs_finish_page(output);
125
126         if (ZSTD_isError(zstd_err)) {
127                 ERROR("zstd decompression error: %d\n",
128                                 (int)ZSTD_getErrorCode(zstd_err));
129                 goto out;
130         }
131
132         if (k < b)
133                 goto out;
134
135         return (int)total_out;
136
137 out:
138         for (; k < b; k++)
139                 put_bh(bh[k]);
140
141         return -EIO;
142 }
143
144 const struct squashfs_decompressor squashfs_zstd_comp_ops = {
145         .init = zstd_init,
146         .free = zstd_free,
147         .decompress = zstd_uncompress,
148         .id = ZSTD_COMPRESSION,
149         .name = "zstd",
150         .supported = 1
151 };