knox hat die Gist bearbeitet . Zu Änderung gehen
1 file changed, 21 insertions
Burn-readme-4.rs(Datei erstellt)
@@ -0,0 +1,21 @@ | |||
1 | + | use burn::nn; | |
2 | + | use burn::module::Module; | |
3 | + | use burn::tensor::backend::Backend; | |
4 | + | ||
5 | + | #[derive(Module, Debug)] | |
6 | + | pub struct PositionWiseFeedForward<B: Backend> { | |
7 | + | linear_inner: nn::Linear<B>, | |
8 | + | linear_outer: nn::Linear<B>, | |
9 | + | dropout: nn::Dropout, | |
10 | + | gelu: nn::Gelu, | |
11 | + | } | |
12 | + | ||
13 | + | impl<B: Backend> PositionWiseFeedForward<B> { | |
14 | + | pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> { | |
15 | + | let x = self.linear_inner.forward(input); | |
16 | + | let x = self.gelu.forward(x); | |
17 | + | let x = self.dropout.forward(x); | |
18 | + | ||
19 | + | self.linear_outer.forward(x) | |
20 | + | } | |
21 | + | } |
Neuer
Älter