Jina-15.py
· 695 B · Python
Unformatted
@requests(on='/stream')
async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument:
    """Stream generated tokens for ``doc.prompt``, one document per token.

    Tokenizes the prompt, then repeatedly calls ``self.model.generate`` with
    ``max_new_tokens=1``, yielding a ``ModelOutputDocument`` after each step
    containing the newest token id and the full decoded continuation so far.
    Stops after ``doc.max_tokens`` tokens or as soon as the model emits its
    end-of-sequence token (the EOS step itself is not yielded).

    NOTE(review): each step re-runs the model over the whole sequence
    (no KV cache is carried between calls) — acceptable for a demo, slow
    for long generations.
    """
    # Renamed from `input` to avoid shadowing the builtin.
    inputs = tokenizer(doc.prompt, return_tensors='pt')
    prompt_len = inputs['input_ids'].shape[1]
    for _ in range(doc.max_tokens):
        output = self.model.generate(**inputs, max_new_tokens=1)
        # `.item()` turns the 0-dim tensor into a plain int for a clean
        # comparison (and below, a clean scalar field value).
        new_token = output[0, -1].item()
        if new_token == tokenizer.eos_token_id:
            break
        yield ModelOutputDocument(
            token_id=new_token,
            # Decode only the continuation, skipping the prompt tokens.
            generated_text=tokenizer.decode(
                output[0, prompt_len:], skip_special_tokens=True
            ),
        )
        # Feed the extended sequence back in for the next single-token step.
        # ones_like matches input_ids' dtype (int64), unlike torch.ones,
        # which would produce a float32 attention mask.
        inputs = {
            'input_ids': output,
            'attention_mask': torch.ones_like(output),
        }
1 | @requests(on='/stream') |
2 | async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument: |
3 | input = tokenizer(doc.prompt, return_tensors='pt') |
4 | input_len = input['input_ids'].shape[1] |
5 | for _ in range(doc.max_tokens): |
6 | output = self.model.generate(**input, max_new_tokens=1) |
7 | if output[0][-1] == tokenizer.eos_token_id: |
8 | break |
9 | yield ModelOutputDocument( |
10 | token_id=output[0][-1], |
11 | generated_text=tokenizer.decode( |
12 | output[0][input_len:], skip_special_tokens=True |
13 | ), |
14 | ) |
15 | input = { |
16 | 'input_ids': output, |
17 | 'attention_mask': torch.ones(1, len(output[0])), |
18 | } |