Overwriting shaders
The code described in this section can be run from the overwriting_shaders.rs example with the command cargo run --features derive --example overwriting_shaders. The complete source code is also provided at the end of this page.
It is possible to replace the WGSL sources of any shader with another source file (assuming it exports the same symbols and/or compute kernels). This is particularly useful if you need to inject special behavior into a WGSL function, fix a bug, or improve performance in a shader belonging to a dependency you don’t control. It can also be used to overwrite any embedded shader with a source file distributed separately from your executable (for example, if some of your game shaders are bundled in a separate assets directory).
Given a shader that implements the Shader trait through derive(Shader), overwriting its sources only requires a single line of code:
MyShader::set_wgsl_path("path/to/your/shader.wgsl");
Since this is a static method, it sets the WGSL path globally. For this change to take effect, any compute pipeline that was constructed from this shader (or that depends directly or indirectly on this shader) has to be re-created; subsequent calls to Shader::from_device, as well as to any other function provided by the Shader trait, will automatically take the new path into account. Typically, you would set this path before constructing any compute pipeline.
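To illustrate why pipelines built before the override must be re-created, here is a std-only model with no GPU and no wgcore involved. The functions set_wgsl_path and current_source, the Pipeline type, and its from_device constructor are hypothetical stand-ins that only mimic the behavior described above: a global setting that is read once, when a pipeline is built.

```rust
use std::sync::Mutex;

// Hypothetical global override slot, standing in for the state written by a
// static `set_wgsl_path`. `None` means "use the embedded source".
static OVERRIDE_PATH: Mutex<Option<String>> = Mutex::new(None);

// Placeholder for the source embedded at compile time with `include_str!`.
const EMBEDDED_SOURCE: &str = "<embedded source>";

// Hypothetical stand-in for `MyShader::set_wgsl_path`: a free/static function,
// so it affects every later lookup rather than one particular instance.
fn set_wgsl_path(path: &str) {
    *OVERRIDE_PATH.lock().unwrap() = Some(path.to_string());
}

// Returns the override if one was set, otherwise the embedded source.
fn current_source() -> String {
    OVERRIDE_PATH
        .lock()
        .unwrap()
        .clone()
        .unwrap_or_else(|| EMBEDDED_SOURCE.to_string())
}

// A pipeline captures its source once, at construction time.
struct Pipeline {
    source: String,
}

impl Pipeline {
    // Hypothetical stand-in for `Shader::from_device`.
    fn from_device() -> Self {
        Pipeline { source: current_source() }
    }
}

fn main() {
    let old = Pipeline::from_device();
    set_wgsl_path("assets/overwritten_dependency.wgsl");
    let new = Pipeline::from_device();

    // `old` keeps the source it was built with; only `new` sees the override.
    println!("old pipeline: {}", old.source);
    println!("new pipeline: {}", new.source);
}
```

This is also why setting the path before building any pipeline is the simplest approach: nothing built earlier can be left holding the old source.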
The shader at the path set with set_wgsl_path will not be embedded in the built executable or library, so make sure that the path is always valid relative to wherever your executable will be run. Such shaders are typically deployed alongside the executable at a known path.
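One way to keep such a path valid regardless of the working directory is to build it from the executable's own location and only apply the override when the file is actually there. The sketch below assumes, purely for illustration, that overriding shaders ship in an assets/shaders directory next to the executable; the helpers shader_path_next_to and existing_override are hypothetical, and only std::env::current_exe and std::path are real APIs.

```rust
use std::path::{Path, PathBuf};

// Hypothetical layout: overriding shaders ship in `assets/shaders`, next to
// the executable. Adjust to your own deployment.
fn shader_path_next_to(exe_dir: &Path, file_name: &str) -> PathBuf {
    exe_dir.join("assets").join("shaders").join(file_name)
}

// Returns the override path only if the file exists, so a missing asset
// leaves the embedded shader in use rather than pointing at an unreadable file.
fn existing_override(exe_dir: &Path, file_name: &str) -> Option<PathBuf> {
    let path = shader_path_next_to(exe_dir, file_name);
    path.is_file().then_some(path)
}

fn main() -> std::io::Result<()> {
    // Anchor the lookup on the executable's location rather than on the
    // current working directory, which depends on how the program is launched.
    let exe = std::env::current_exe()?;
    let exe_dir = exe.parent().unwrap_or(Path::new("."));

    match existing_override(exe_dir, "overwritten_dependency.wgsl") {
        // This is where a real program would call `MyShader::set_wgsl_path(...)`.
        Some(path) => println!("overriding shader with {}", path.display()),
        None => println!("no override found, keeping the embedded shader"),
    }
    Ok(())
}
```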
All the overwritten shader paths are registered globally in wgcore::shader::ShaderRegistry. This is essentially a hash-map between a shader type and its overwritten path. Another way to obtain the known path for a shader is by calling Shader::wgsl_path.
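The idea of a global hash-map keyed by shader type can be sketched with std alone, using std::any::TypeId as the key. This is not wgcore's actual ShaderRegistry: the functions registry, set_path and get_path are hypothetical names chosen for illustration, and the marker structs only stand in for derived shaders.

```rust
use std::any::TypeId;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Mutex, OnceLock};

// Hypothetical model of a global registry mapping a shader type to its
// overwritten path. Not wgcore's implementation.
fn registry() -> &'static Mutex<HashMap<TypeId, PathBuf>> {
    static REGISTRY: OnceLock<Mutex<HashMap<TypeId, PathBuf>>> = OnceLock::new();
    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}

// Records an override for shader type `T`, replacing any previous one.
fn set_path<T: 'static>(path: impl Into<PathBuf>) {
    registry().lock().unwrap().insert(TypeId::of::<T>(), path.into());
}

// Returns the override for `T`, or `None` if it still uses its embedded source.
fn get_path<T: 'static>() -> Option<PathBuf> {
    registry().lock().unwrap().get(&TypeId::of::<T>()).cloned()
}

// Marker types standing in for two derived shaders.
struct Composable;
struct WgKernel;

fn main() {
    set_path::<Composable>("assets/overwritten_dependency.wgsl");
    // Only `Composable` has an entry; `WgKernel` is unaffected.
    println!("{:?}", get_path::<Composable>());
    println!("{:?}", get_path::<WgKernel>());
}
```

Keying by type rather than by file name means two shaders that happen to share a file name can still be overridden independently.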
Complete example
- main.rs
- kernel.wgsl
#[cfg(not(feature = "derive"))]
std::compile_error!(
r#"
###############################################################
## The `derive` feature must be enabled to run this example. ##
###############################################################
"#
);
use nalgebra::DVector;
use std::fmt::Debug;
use wgcore::gpu::GpuInstance;
use wgcore::kernel::{KernelInvocationBuilder, KernelInvocationQueue};
use wgcore::tensor::GpuVector;
use wgcore::Shader;
use wgpu::{BufferUsages, ComputePipeline};
// Declare our shader module that contains our composable functions.
// Note that we don’t build any compute pipeline from this wgsl file.
#[derive(Shader)]
#[shader(
src = "compose_dependency.wgsl" // Shader source code, will be embedded in the exe with `include_str!`
)]
struct Composable;
#[derive(Shader)]
#[shader(
derive(Composable), // This shader depends on the `Composable` shader.
src = "compose_kernel.wgsl", // Shader source code, will be embedded in the exe with `include_str!`.
composable = false // This shader doesn’t export any symbols reusable from other wgsl shaders.
)]
struct WgKernel {
// This ComputePipeline field indicates that the Shader macro needs to generate the boilerplate
// for loading the compute pipeline in `WgKernel::from_device`.
main: ComputePipeline,
}
#[derive(Copy, Clone, PartialEq, Default, bytemuck::Pod, bytemuck::Zeroable)]
#[repr(C)]
pub struct MyStruct {
value: f32,
}
// Optional: makes the debug output more concise.
impl Debug for MyStruct {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.value)
}
}
#[async_std::main]
async fn main() -> anyhow::Result<()> {
// Initialize the gpu device and its queue.
//
// Note that `GpuInstance` is just a simple helper struct for initializing the gpu resources.
// You are free to initialize them independently if more control is needed, or reuse the ones
// that were already created/owned by e.g., a game engine.
let gpu = GpuInstance::new().await?;
// Load and compile our kernel. The `from_device` function was generated by the `Shader` derive.
// Note that its dependency to `Composable` is automatically resolved by the `Shader` derive
// too.
let kernel_before_overwrite = WgKernel::from_device(gpu.device())?;
// Run the original shader.
let result_before_overwrite = run_kernel(&gpu, &kernel_before_overwrite).await;
// Overwrite the sources of the dependency module.
// Since we are running this with `cargo run --example`, the path is relative to the
// `target/debug` folder.
Composable::set_wgsl_path("../../crates/wgcore/examples/overwritten_dependency.wgsl");
// Recompile our kernel.
let kernel_after_overwrite = WgKernel::from_device(gpu.device())?;
// Run the modified kernel.
let result_after_overwrite = run_kernel(&gpu, &kernel_after_overwrite).await;
println!("Result before overwrite: {:?}", result_before_overwrite);
println!("Result after overwrite: {:?}", result_after_overwrite);
Ok(())
}
async fn run_kernel(gpu: &GpuInstance, kernel: &WgKernel) -> Vec<MyStruct> {
// Create the buffers.
const LEN: u32 = 10;
let a_data = DVector::from_fn(LEN as usize, |i, _| MyStruct { value: i as f32 });
let b_data = DVector::from_fn(LEN as usize, |i, _| MyStruct {
value: i as f32 * 10.0,
});
let a_buf = GpuVector::init(
gpu.device(),
&a_data,
BufferUsages::STORAGE | BufferUsages::COPY_SRC,
);
let b_buf = GpuVector::init(gpu.device(), &b_data, BufferUsages::STORAGE);
let staging = GpuVector::uninit(
gpu.device(),
LEN,
BufferUsages::COPY_DST | BufferUsages::MAP_READ,
);
// Queue the operation.
let mut queue = KernelInvocationQueue::new(gpu.device());
KernelInvocationBuilder::new(&mut queue, &kernel.main)
.bind0([a_buf.buffer(), b_buf.buffer()])
.queue(LEN.div_ceil(64));
// Encode & submit the operation to the gpu.
let mut encoder = gpu.device().create_command_encoder(&Default::default());
queue.encode(&mut encoder, None);
// Copy the result to the staging buffer.
staging.copy_from(&mut encoder, &a_buf);
gpu.queue().submit(Some(encoder.finish()));
// Read the result back from the gpu.
staging
.read(gpu.device())
.await
.expect("Failed to read result from the GPU.")
}
#define_import_path composable::module
struct MyStruct {
value: f32,
}
fn shared_function(a: MyStruct, b: MyStruct) -> MyStruct {
// Same as compose_dependency.wgsl but with a subtraction instead of an addition.
return MyStruct(a.value - b.value);
}
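The overriding shader above only swaps the addition in shared_function for a subtraction, so the effect on the output can be predicted on the CPU. The sketch below assumes the kernel (compose_kernel.wgsl, not shown on this page) stores shared_function(a[i], b[i]) into a[i], which is consistent with a_buf being the buffer copied back, but is an assumption. Under that assumption, the inputs a[i] = i and b[i] = 10 * i from run_kernel give 11 * i before the overwrite and -9 * i after it. The expected function is a hypothetical helper.

```rust
// CPU reference for the example's expected output.
// ASSUMPTION: the kernel computes a[i] = shared_function(a[i], b[i]).
fn expected(len: usize, op: fn(f32, f32) -> f32) -> Vec<f32> {
    (0..len)
        .map(|i| {
            let a = i as f32; // same as `a_data` in `run_kernel`
            let b = i as f32 * 10.0; // same as `b_data` in `run_kernel`
            op(a, b)
        })
        .collect()
}

fn main() {
    const LEN: usize = 10;
    // compose_dependency.wgsl uses an addition.
    let before = expected(LEN, |a, b| a + b);
    // overwritten_dependency.wgsl uses a subtraction.
    let after = expected(LEN, |a, b| a - b);
    println!("Result before overwrite: {:?}", before);
    println!("Result after overwrite: {:?}", after);
}
```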