@requests(on='/stream')
async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument:
    """Stream generated tokens for ``doc.prompt`` one at a time.

    This is an async generator. Each iteration runs the model for exactly one
    new token and yields a ``ModelOutputDocument`` containing:

    - that token's id;
    - the decoded text of everything generated so far (prompt excluded).

    Generation stops after ``doc.max_tokens`` tokens, or as soon as the model
    emits ``tokenizer.eos_token_id``. The EOS token itself is not yielded.

    NOTE(review): the ``-> ModelOutputDocument`` annotation names the
    per-item type, not the async-generator type. It is kept as-is on the
    assumption that the serving framework reads it to infer the output
    schema; confirm before changing it.

    Args:
        doc: request document. ``doc.prompt`` is the text to continue and
            ``doc.max_tokens`` caps the number of new tokens.
        **kwargs: extra arguments supplied by the framework; unused here.

    Yields:
        One ``ModelOutputDocument`` per generated (non-EOS) token.
    """
    # Named ``model_inputs`` rather than ``input`` to avoid shadowing the builtin.
    model_inputs = tokenizer(doc.prompt, return_tensors='pt')
    # Number of prompt tokens; everything after this index is generated text.
    input_len = model_inputs['input_ids'].shape[1]
    for _ in range(doc.max_tokens):
        # NOTE(review): each step re-runs generate() over the full sequence
        # without reusing past key/values, so total work grows quadratically
        # with the number of streamed tokens. generate() is also a blocking
        # call inside an async function, so it holds the event loop while it
        # runs.
        output = self.model.generate(**model_inputs, max_new_tokens=1)
        # Extract the newest token once, as a plain Python int. It compares
        # cleanly with eos_token_id and is presumably what the document's
        # ``token_id`` field expects (verify against ModelOutputDocument).
        next_token_id = output[0, -1].item()
        if next_token_id == tokenizer.eos_token_id:
            break
        yield ModelOutputDocument(
            token_id=next_token_id,
            generated_text=tokenizer.decode(
                output[0][input_len:], skip_special_tokens=True
            ),
        )
        # Feed the extended sequence back in for the next step.
        # ones_like keeps the mask on the same device, dtype and shape as
        # input_ids. The previous torch.ones(1, n) always produced a float
        # tensor on the default device with batch size fixed at 1.
        model_inputs = {
            'input_ids': output,
            'attention_mask': torch.ones_like(output),
        }