use burn::nn; use burn::module::Module; use burn::tensor::backend::Backend; #[derive(Module, Debug)] pub struct PositionWiseFeedForward { linear_inner: nn::Linear, linear_outer: nn::Linear, dropout: nn::Dropout, gelu: nn::Gelu, } impl PositionWiseFeedForward { pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input); let x = self.gelu.forward(x); let x = self.dropout.forward(x); self.linear_outer.forward(x) } }