knox revised this gist . Go to revision
1 file changed, 18 insertions
Jina-15.py(file created)
@@ -0,0 +1,18 @@ | |||
1 | + | @requests(on='/stream') | |
2 | + | async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument: | |
3 | + | input = tokenizer(doc.prompt, return_tensors='pt') | |
4 | + | input_len = input['input_ids'].shape[1] | |
5 | + | for _ in range(doc.max_tokens): | |
6 | + | output = self.model.generate(**input, max_new_tokens=1) | |
7 | + | if output[0][-1] == tokenizer.eos_token_id: | |
8 | + | break | |
9 | + | yield ModelOutputDocument( | |
10 | + | token_id=output[0][-1], | |
11 | + | generated_text=tokenizer.decode( | |
12 | + | output[0][input_len:], skip_special_tokens=True | |
13 | + | ), | |
14 | + | ) | |
15 | + | input = { | |
16 | + | 'input_ids': output, | |
17 | + | 'attention_mask': torch.ones(1, len(output[0])), | |
18 | + | } |
Newer
Older