Dernière activité 1731289879

knox a révisé ce gist 1731289879. Aller à la révision

1 file changed, 18 insertions

Jina-15.py(fichier créé)

@@ -0,0 +1,18 @@
1 + @requests(on='/stream')
2 + async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument:
3 + input = tokenizer(doc.prompt, return_tensors='pt')
4 + input_len = input['input_ids'].shape[1]
5 + for _ in range(doc.max_tokens):
6 + output = self.model.generate(**input, max_new_tokens=1)
7 + if output[0][-1] == tokenizer.eos_token_id:
8 + break
9 + yield ModelOutputDocument(
10 + token_id=output[0][-1],
11 + generated_text=tokenizer.decode(
12 + output[0][input_len:], skip_special_tokens=True
13 + ),
14 + )
15 + input = {
16 + 'input_ids': output,
17 + 'attention_mask': torch.ones(1, len(output[0])),
18 + }
Plus récent Plus ancien