Skip to content

Commit

Permalink
Add logic to compute channel grids for modular.
Browse files Browse the repository at this point in the history
Also compute the relevant transforms "decomposed" to operate on grid
chunks, and the dependencies in decoding (i.e. which inputs are needed
for which transform).
  • Loading branch information
veluca93 committed Jan 1, 2025
1 parent 6600b6b commit ddb0ba5
Show file tree
Hide file tree
Showing 3 changed files with 395 additions and 14 deletions.
53 changes: 42 additions & 11 deletions jxl/src/frame/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
extra_channels::ExtraChannelInfo, frame_header::FrameHeader, modular::GroupHeader,
JxlHeader,
},
image::Image,
util::{tracing_wrappers::*, CeilLog2},
};

Expand All @@ -20,7 +21,7 @@ mod transforms;
mod tree;

pub use predict::Predictor;
use transforms::TransformStep;
use transforms::{make_grids, TransformStepChunk};
pub use tree::Tree;

#[derive(Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -58,6 +59,27 @@ impl ChannelInfo {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
enum ModularGridKind {
// Single big channel.
None,
// 2048x2048 image-pixels.
Lf,
// 256x256 image-pixels.
Hf,
}

#[allow(dead_code)]
#[derive(Debug)]
struct ModularBuffer {
data: Option<Image<i32>>,
// Holds additional information such as the weighted predictor's error channel's last row for
// the transform chunk that produced this buffer.
auxiliary_data: Option<Image<i32>>,
remaining_uses: usize,
used_by_transforms: Vec<usize>,
}

#[allow(dead_code)]
#[derive(Debug)]
struct ModularBufferInfo {
Expand All @@ -67,6 +89,8 @@ struct ModularBufferInfo {
is_output: bool,
is_coded: bool,
description: String,
grid_kind: ModularGridKind,
buffer_grid: Vec<ModularBuffer>,
}

/// A modular image is a sequence of channels to which one or more transforms might have been
Expand All @@ -81,10 +105,10 @@ struct ModularBufferInfo {
#[derive(Debug)]
pub struct FullModularImage {
buffer_info: Vec<ModularBufferInfo>,
transform_steps: Vec<TransformStep>,
transform_steps: Vec<TransformStepChunk>,
// List of buffer indices of the channels of the modular image encoded in each kind of section.
// In order, LfGlobal, LfGroup, HfGroup(pass 0), ..., HfGroup(last pass).
section_modular_images: Vec<Vec<usize>>,
section_buffer_indices: Vec<Vec<usize>>,
}

impl FullModularImage {
Expand Down Expand Up @@ -127,13 +151,13 @@ impl FullModularImage {
return Err(Error::NoGlobalTree);
}

let (buffer_info, transform_steps) =
let (mut buffer_info, transform_steps) =
transforms::meta_apply_transforms(&channels, &header.transforms)?;

// Assign each (channel, group) pair present in the bitstream to the section in which it will be decoded.
let mut section_modular_images: Vec<Vec<usize>> = vec![];
let mut section_buffer_indices: Vec<Vec<usize>> = vec![];

section_modular_images.push(
section_buffer_indices.push(
buffer_info
.iter()
.enumerate()
Expand All @@ -143,7 +167,7 @@ impl FullModularImage {
.collect(),
);

section_modular_images.push(
section_buffer_indices.push(
buffer_info
.iter()
.enumerate()
Expand All @@ -156,7 +180,7 @@ impl FullModularImage {

for pass in 0..frame_header.passes.num_passes as usize {
let (min_shift, max_shift) = frame_header.passes.downsampling_bracket(pass);
section_modular_images.push(
section_buffer_indices.push(
buffer_info
.iter()
.enumerate()
Expand All @@ -169,18 +193,25 @@ impl FullModularImage {
}

// Ensure that the channel list in each group is sorted by actual channel ID.
for list in section_modular_images.iter_mut() {
for list in section_buffer_indices.iter_mut() {
list.sort_by_key(|x| buffer_info[*x].channel_id);
}

// TODO(veluca93): prepare grids, grid sizes and dependency counts for the various grids.
trace!(?section_buffer_indices);

let transform_steps = make_grids(
frame_header,
transform_steps,
&section_buffer_indices,
&mut buffer_info,
);

// TODO(veluca93): read global channels

Ok(FullModularImage {
buffer_info,
transform_steps,
section_modular_images,
section_buffer_indices,
})
}
}
2 changes: 1 addition & 1 deletion jxl/src/frame/modular/predict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use num_derive::FromPrimitive;
use num_traits::FromPrimitive;

#[repr(u8)]
#[derive(Debug, FromPrimitive, Clone, Copy)]
#[derive(Debug, FromPrimitive, Clone, Copy, PartialEq, Eq)]
pub enum Predictor {
Zero = 0,
West = 1,
Expand Down
Loading

0 comments on commit ddb0ba5

Please sign in to comment.