Skip to content

Commit

Permalink
- add support for more ray tracing shader types in raytracing pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
polymonster authored and GBDixonAlex committed Jan 11, 2025
1 parent 7179676 commit b33efce
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 60 deletions.
26 changes: 0 additions & 26 deletions examples/raytraced_triangle/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,36 +56,10 @@ fn main() -> Result<(), hotline_rs::Error> {
let mut swap_chain = device.create_swap_chain::<os_platform::App>(&swap_chain_info, &window)?;
let mut cmd = device.create_cmd_buf(num_buffers);


let mut pmfx : pmfx::Pmfx<gfx_platform::Device> = pmfx::Pmfx::create(&mut device, 0);
pmfx.load(&hotline_rs::get_data_path("shaders/raytracing_example"))?;
pmfx.create_raytracing_pipeline(&device, "raytracing");

/*
// create raytracing shaders
let raygen_shader = device.create_shader(&gfx::ShaderInfo {
shader_type: gfx::ShaderType::RayGen,
compile_info: None
}, &fs::read(hotline_rs::get_data_path("shaders/raygen.cso"))?)?;
let closest_hit_shader = device.create_shader(&gfx::ShaderInfo {
shader_type: gfx::ShaderType::ClosestHit,
compile_info: None
}, &fs::read(hotline_rs::get_data_path("shaders/closesthit.cso"))?)?;
let miss_shader = device.create_shader(&gfx::ShaderInfo {
shader_type: gfx::ShaderType::Miss,
compile_info: None
}, &fs::read(hotline_rs::get_data_path("shaders/miss.cso"))?)?;
// create raytracing pipeline
let raytracing_pipeline = device.create_raytracing_pipeline(&RaytracingPipelineInfo{
raygen_shader: Some((&raygen_shader, "MyRaygenShader")),
closest_hit_shader: None, //Some((&closest_hit_shader, "MyClosestHitShader")),
miss_shader: None, // Some((&miss_shader, "MyMissShader")),
});
*/

while app.run() {
// update window and swap chain
window.update(&mut app);
Expand Down
20 changes: 14 additions & 6 deletions src/gfx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,15 +746,23 @@ pub struct ComputePipelineInfo<'stack, D: Device> {
pub pipeline_layout: PipelineLayout,
}

pub struct RaytracingShader<'stack, D: Device> {
pub shader: &'stack D::Shader,
pub entry_point: String,
pub stage: ShaderType
}

/// Information to create a raytracing pipeline through `Device::create_raytracing_pipeline`
pub struct RaytracingPipelineInfo<'stack, D: Device> {
pub raygen_shader: Option<(&'stack D::Shader, &'stack str)>,
pub any_hit_shader: Option<(&'stack D::Shader, &'stack str)>,
pub closest_hit_shader: Option<(&'stack D::Shader, &'stack str)>,
pub miss_shader: Option<(&'stack D::Shader, &'stack str)>,
pub intersection_shader: Option<(&'stack D::Shader, &'stack str)>,
pub callable_shader: Option<(&'stack D::Shader, &'stack str)>,

//pub raygen_shader: Option<(&'stack D::Shader, String)>,
//pub any_hit_shader: Option<(&'stack D::Shader, String)>,
//pub closest_hit_shader: Option<(&'stack D::Shader, String)>,
//pub miss_shader: Option<(&'stack D::Shader, String)>,
//pub intersection_shader: Option<(&'stack D::Shader, String)>,
//pub callable_shader: Option<(&'stack D::Shader, String)>,

pub shaders: Vec<RaytracingShader<'stack, D>>,
pub pipeline_layout: PipelineLayout,
}

Expand Down
50 changes: 28 additions & 22 deletions src/gfx/d3d12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2967,9 +2967,6 @@ impl super::Device for Device {

let mut subobjects = Vec::new();

let (shader, entry) = info.raygen_shader.unwrap();
let wide_name = os::win32::string_to_wide(entry.to_string());

// TODO: triangle hit group

// rt shader config
Expand All @@ -2995,25 +2992,34 @@ impl super::Device for Device {
};
subobjects.push(pipeline_config_subobject);

// dxil library
let dxil_library = D3D12_DXIL_LIBRARY_DESC {
DXILLibrary: D3D12_SHADER_BYTECODE {
pShaderBytecode: shader.get_buffer_pointer(),
BytecodeLength: shader.get_buffer_size(),
},
NumExports: 1,
pExports: &D3D12_EXPORT_DESC {
Name: windows_core::PCWSTR(wide_name.as_ptr() as *const _),
ExportToRename: windows_core::PCWSTR(std::ptr::null()),
Flags: D3D12_EXPORT_FLAGS(0),
},
..Default::default()
};
let dxil_library_subobject = D3D12_STATE_SUBOBJECT {
Type: D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY,
pDesc: &dxil_library as *const _ as *const _,
};
subobjects.push(dxil_library_subobject);
// widen entry point strings
let wide_entry_points: Vec<_> = info.shaders
.iter()
.map(|x| os::win32::string_to_wide(x.entry_point.clone()))
.collect();

// dxil library shaders
for (index, shader) in info.shaders.iter().enumerate() {
// dxil library
let dxil_library = D3D12_DXIL_LIBRARY_DESC {
DXILLibrary: D3D12_SHADER_BYTECODE {
pShaderBytecode: shader.shader.get_buffer_pointer(),
BytecodeLength: shader.shader.get_buffer_size(),
},
NumExports: 1,
pExports: &D3D12_EXPORT_DESC {
Name: windows_core::PCWSTR(wide_entry_points[index].as_ptr() as *const _),
ExportToRename: windows_core::PCWSTR(std::ptr::null()),
Flags: D3D12_EXPORT_FLAGS(0),
},
..Default::default()
};
let dxil_library_subobject = D3D12_STATE_SUBOBJECT {
Type: D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY,
pDesc: &dxil_library as *const _ as *const _,
};
subobjects.push(dxil_library_subobject);
}

// root signature, for now we use a global one per pipeline
let root_signature = self.create_root_signature_with_lookup(&info.pipeline_layout)?;
Expand Down
43 changes: 37 additions & 6 deletions src/pmfx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,18 @@ fn to_gfx_clear_depth_stencil(clear_depth: Option<f32>, clear_stencil: Option<u8
}
}

fn get_shader_entry_point_name(shader_name: Option<String>) -> Option<String> {
if let Some(shader_name) = shader_name {
Path::new(&shader_name)
.file_stem()
.and_then(|os_str| os_str.to_str())
.map(|s| s.to_string())
}
else {
None
}
}

impl<D> Pmfx<D> where D: gfx::Device {
/// Create a new empty pmfx instance
pub fn create(device: &mut D, shader_heap_size: usize) -> Self {
Expand Down Expand Up @@ -2092,16 +2104,35 @@ impl<D> Pmfx<D> where D: gfx::Device {
self.create_shader(device, Path::new(&folder), &pipeline.ch)?;
self.create_shader(device, Path::new(&folder), &pipeline.ah)?;
self.create_shader(device, Path::new(&folder), &pipeline.mi)?;
self.create_shader(device, Path::new(&folder), &pipeline.is)?;
self.create_shader(device, Path::new(&folder), &pipeline.ca)?;
}

for (_, pipeline) in self.pmfx.pipelines[pipeline_name].clone() {

// build shader info vector
let stages = vec![
(pipeline.rg, gfx::ShaderType::RayGen),
(pipeline.ch, gfx::ShaderType::ClosestHit),
(pipeline.ah, gfx::ShaderType::AnyHit),
(pipeline.mi, gfx::ShaderType::Miss),
(pipeline.is, gfx::ShaderType::Intersection),
(pipeline.ca, gfx::ShaderType::Callable),
];

let shaders = stages
.iter()
.filter(|x| x.0.is_some())
.map(|x| (x.0.as_ref().unwrap(), x.1))
.map(|x| gfx::RaytracingShader {
shader: self.get_shader(&Some(x.0.to_string())).unwrap(),
entry_point: get_shader_entry_point_name(Some(x.0.to_string())).unwrap(),
stage: x.1
})
.collect();

let raytracing_pipeline = device.create_raytracing_pipeline(&RaytracingPipelineInfo{
raygen_shader: if let Some(rg) = self.get_shader(&pipeline.rg) { Some((rg, "MyRaygenShader")) } else { None },
any_hit_shader: None,
closest_hit_shader: None,
miss_shader: None,
intersection_shader: None,
callable_shader: None,
shaders,
pipeline_layout: pipeline.pipeline_layout.clone()
});
}
Expand Down

0 comments on commit b33efce

Please sign in to comment.