// Last activity: 1728856729 (presumably a Unix timestamp)

// Burn-readme-4.rs (raw)
use burn::module::Module;
use burn::nn;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
5#[derive(Module, Debug)]
6pub struct PositionWiseFeedForward<B: Backend> {
7 linear_inner: nn::Linear<B>,
8 linear_outer: nn::Linear<B>,
9 dropout: nn::Dropout,
10 gelu: nn::Gelu,
11}
13impl<B: Backend> PositionWiseFeedForward<B> {
14 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
15 let x = self.linear_inner.forward(input);
16 let x = self.gelu.forward(x);
17 let x = self.dropout.forward(x);
18
19 self.linear_outer.forward(x)
20 }
21}