##// END OF EJS Templates
branching: merge stable into default
branching: merge stable into default

File last commit:

r54024:f16a7f3c stable
r54034:1b4a024f merge default
Show More
decompressor_multi.rs
289 lines | 9.4 KiB | application/rls-services+xml | RustLexer
// Copyright (c) 2021-present, Gregory Szorc
// All rights reserved.
//
// This software may be modified and distributed under the terms
// of the BSD license. See the LICENSE file for details.
use {
crate::{
buffers::{BufferSegment, ZstdBufferWithSegments, ZstdBufferWithSegmentsCollection},
compression_dict::ZstdCompressionDict,
exceptions::ZstdError,
zstd_safe::DCtx,
},
pyo3::{
buffer::PyBuffer,
exceptions::{PyTypeError, PyValueError},
prelude::*,
types::{PyBytes, PyList, PyTuple},
},
rayon::prelude::*,
};
/// One compressed input frame queued for decompression.
struct DataSource<'a> {
/// Raw bytes of the compressed frame.
data: &'a [u8],
/// Caller-supplied decompressed size in bytes. A value of 0 makes the
/// worker read the size from the frame header instead.
decompressed_size: usize,
}
pub fn multi_decompress_to_buffer(
py: Python,
dict_data: Option<&Py<ZstdCompressionDict>>,
frames: &Bound<'_, PyAny>,
decompressed_sizes: Option<&Bound<'_, PyAny>>,
threads: isize,
) -> PyResult<ZstdBufferWithSegmentsCollection> {
let threads = if threads < 0 {
num_cpus::get()
} else if threads < 2 {
1
} else {
threads as _
};
let frame_sizes: &[u64] = if let Some(frames_sizes) = decompressed_sizes {
let buffer: PyBuffer<u8> = PyBuffer::get_bound(&frames_sizes.as_borrowed())?;
unsafe { std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes() / 8) }
} else {
&[]
};
let mut sources = vec![];
if let Ok(buffer) = frames.downcast::<ZstdBufferWithSegments>() {
if decompressed_sizes.is_some() && frame_sizes.len() != buffer.len()? {
return Err(PyValueError::new_err(format!(
"decompressed_sizes size mismatch; expected {}, got {}",
buffer.len()?,
frame_sizes.len()
)));
}
let borrow = buffer.borrow();
sources.reserve_exact(borrow.segments.len());
for i in 0..borrow.segments.len() {
let slice = borrow.get_segment_slice(py, i);
sources.push(DataSource {
data: slice,
decompressed_size: *frame_sizes.get(i).unwrap_or(&0) as _,
});
}
} else if let Ok(collection) = frames.downcast::<ZstdBufferWithSegmentsCollection>() {
let frames_count = collection.len()?;
if decompressed_sizes.is_some() && frame_sizes.len() != frames_count {
return Err(PyValueError::new_err(format!(
"decompressed_sizes size mismatch; expected {}, got {}",
frames_count,
frame_sizes.len()
)));
}
sources.reserve_exact(frames_count);
let mut offset = 0;
for buffer_obj in &collection.borrow().buffers {
let buffer = buffer_obj.downcast_bound::<ZstdBufferWithSegments>(py)?;
let borrow = buffer.borrow();
for i in 0..borrow.segments.len() {
let slice = borrow.get_segment_slice(py, i);
sources.push(DataSource {
data: slice,
decompressed_size: *frame_sizes.get(offset).unwrap_or(&0) as _,
});
offset += 1;
}
}
} else if let Ok(list) = frames.extract::<&PyList>() {
if decompressed_sizes.is_some() && frame_sizes.len() != list.len() {
return Err(PyValueError::new_err(format!(
"decompressed_sizes size mismatch; expected {}; got {}",
list.len(),
frame_sizes.len()
)));
}
sources.reserve_exact(list.len());
for (i, item) in list.iter().enumerate() {
let buffer: PyBuffer<u8> = PyBuffer::get_bound(&item.as_borrowed())
.map_err(|_| PyTypeError::new_err(format!("item {} not a bytes like object", i)))?;
let slice = unsafe {
std::slice::from_raw_parts(buffer.buf_ptr() as *const _, buffer.len_bytes())
};
sources.push(DataSource {
data: slice,
decompressed_size: *frame_sizes.get(i).unwrap_or(&0) as _,
});
}
} else {
return Err(PyTypeError::new_err(
"argument must be list of BufferWithSegments",
));
}
decompress_from_datasources(py, dict_data, sources, threads)
}
/// Outcome classification for a single worker decompression.
#[derive(Debug, PartialEq)]
enum WorkerError {
/// No error was recorded for the item.
None,
/// No decompressed size was supplied and the frame content size could
/// not be obtained from the frame.
NoSize,
/// Decompression failed; carries the error message returned by the
/// decompression context.
Zstd(&'static str),
}
/// Holds results of an individual decompression operation.
struct WorkerResult {
/// Index of the source within the input list; used to restore input order.
source_offset: usize,
/// Error recorded for this item, or `WorkerError::None`.
error: WorkerError,
/// Decompressed bytes; set only when decompression succeeded.
data: Option<Vec<u8>>,
}
fn decompress_from_datasources(
py: Python,
dict_data: Option<&Py<ZstdCompressionDict>>,
sources: Vec<DataSource>,
thread_count: usize,
) -> PyResult<ZstdBufferWithSegmentsCollection> {
// More threads than inputs makes no sense.
let thread_count = std::cmp::min(thread_count, sources.len());
// TODO lower thread count when input size is too small and threads
// would add overhead.
let mut dctxs = Vec::with_capacity(thread_count);
let results = std::sync::Mutex::new(Vec::with_capacity(sources.len()));
// TODO there are tons of inefficiencies in this implementation compared
// to the C backend.
for _ in 0..thread_count {
let dctx = DCtx::new().map_err(ZstdError::new_err)?;
if let Some(dict_data) = dict_data {
dict_data.borrow_mut(py).load_into_dctx(&dctx)?;
}
dctxs.push(dctx);
}
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(thread_count)
.build()
.map_err(|err| ZstdError::new_err(format!("error initializing thread pool: {}", err)))?;
pool.install(|| {
sources
.par_iter()
.enumerate()
.for_each(|(index, source): (usize, &DataSource)| {
let thread_index = pool.current_thread_index().unwrap();
let dctx = &dctxs[thread_index];
let mut result = WorkerResult {
source_offset: index,
error: WorkerError::None,
data: None,
};
let decompressed_size = if source.decompressed_size == 0 {
let frame_size = match zstd_safe::get_frame_content_size(source.data) {
Err(zstd_safe::ContentSizeError) => {
result.error = WorkerError::NoSize;
zstd_safe::CONTENTSIZE_ERROR
}
Ok(None) => {
result.error = WorkerError::NoSize;
zstd_safe::CONTENTSIZE_UNKNOWN
}
Ok(Some(frame_size)) => frame_size,
};
frame_size as _
} else {
source.decompressed_size
};
if result.error == WorkerError::None {
let mut dest_buffer = Vec::with_capacity(decompressed_size);
let mut in_buffer = zstd_sys::ZSTD_inBuffer {
src: source.data.as_ptr() as *const _,
size: source.data.len(),
pos: 0,
};
match dctx.decompress_into_vec(&mut dest_buffer, &mut in_buffer) {
Ok(_) => {
result.data = Some(dest_buffer);
}
Err(msg) => {
result.error = WorkerError::Zstd(msg);
}
}
}
results.lock().unwrap().push(result);
});
});
// Need to sort results by their input order or else results aren't
// deterministic.
results
.lock()
.unwrap()
.sort_by(|a, b| a.source_offset.cmp(&b.source_offset));
// TODO this is horribly inefficient due to memory copies.
let els = PyTuple::new_bound(
py,
results
.lock()
.unwrap()
.iter()
.map(|result| {
match result.error {
WorkerError::None => Ok(()),
WorkerError::Zstd(msg) => Err(ZstdError::new_err(format!(
"error decompressing item {}: {}",
result.source_offset, msg
))),
WorkerError::NoSize => Err(PyValueError::new_err(format!(
"could not determine decompressed size of item {}",
result.source_offset
))),
}?;
let data = result.data.as_ref().unwrap();
let chunk = PyBytes::new_bound(py, data);
let segments = vec![BufferSegment {
offset: 0,
length: data.len() as _,
}];
let segments = unsafe {
PyBytes::bound_from_ptr(
py,
segments.as_ptr() as *const _,
segments.len() * std::mem::size_of::<BufferSegment>(),
)
};
let segments_buffer = PyBuffer::get_bound(&segments)?;
Py::new(
py,
ZstdBufferWithSegments::new(py, &chunk, segments_buffer)?,
)
})
.collect::<PyResult<Vec<_>>>()?,
);
ZstdBufferWithSegmentsCollection::new(py, &els)
}