最後活躍 1731289879

Jina-15.py 原始檔案
@requests(on='/stream')
async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument:
    """Stream a completion for ``doc.prompt`` one token at a time.

    The prompt is tokenized once; afterwards each loop iteration asks
    ``self.model.generate`` for exactly one new token and feeds the full
    returned sequence back in as the next input.

    Args:
        doc: Request document. Only ``doc.prompt`` (text to continue) and
            ``doc.max_tokens`` (upper bound on generation steps) are read.
        **kwargs: Accepted but not used by this method.

    Yields:
        One ``ModelOutputDocument`` per generated token, with ``token_id``
        set to the newest token and ``generated_text`` set to the decoded
        text of *all* tokens produced so far (everything after the prompt),
        special tokens skipped.

    Generation ends after ``doc.max_tokens`` steps, or earlier when the
    newest token equals ``tokenizer.eos_token_id``; the EOS token itself is
    not yielded.
    """
    # NOTE(review): the return annotation names the yielded document type
    # rather than an async-generator type; it is kept unchanged because the
    # endpoint decorator may rely on it -- confirm before "fixing" it.
    # ``tokenizer`` is resolved from the enclosing scope, not from ``self``.
    model_inputs = tokenizer(doc.prompt, return_tensors='pt')
    # Remember where the prompt ends so only generated tokens are decoded.
    prompt_len = model_inputs['input_ids'].shape[1]

    for _ in range(doc.max_tokens):
        output = self.model.generate(**model_inputs, max_new_tokens=1)
        newest_token = output[0][-1]
        if newest_token == tokenizer.eos_token_id:
            break

        yield ModelOutputDocument(
            token_id=newest_token,
            generated_text=tokenizer.decode(
                output[0][prompt_len:], skip_special_tokens=True
            ),
        )

        # Feed the whole sequence back in. ones_like keeps the mask on the
        # same device and in the same dtype/shape as the ids returned by
        # generate, instead of a float32 default-device tensor with a
        # hard-coded batch dimension of 1.
        model_inputs = {
            'input_ids': output,
            'attention_mask': torch.ones_like(output),
        }