# Last active: 1731289344 (Unix timestamp, 2024-11-11 01:42:24 UTC)

# Source file: Jina-1.py (raw view)
1from jina import Executor, requests
2from docarray import DocList, BaseDoc
3from transformers import pipeline
4
5
class Prompt(BaseDoc):
    """Request document carrying one prompt string for the model to continue."""

    text: str  # the prompt text that is fed to the generator
8
9
class Generation(BaseDoc):
    """Response document pairing a prompt with the text generated for it."""

    prompt: str  # the prompt exactly as it was received
    text: str  # the model output produced for ``prompt``
13
14
class StableLM(Executor):
    """Jina Executor that serves text completions from a Hugging Face model.

    A ``transformers`` ``text-generation`` pipeline is built once at
    construction time; every request runs each incoming ``Prompt`` through it
    and answers with one ``Generation`` per prompt, in input order.
    """

    def __init__(
        self,
        *,
        model_name: str = 'stabilityai/stablelm-base-alpha-3b',
        **kwargs,
    ):
        """Initialise the Executor and load the generation pipeline.

        :param model_name: Hugging Face model id passed to ``pipeline``;
            defaults to the StableLM base-alpha 3B checkpoint.
        :param kwargs: forwarded unchanged to ``Executor.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): the model is loaded eagerly here, so constructing the
        # Executor is slow and memory-heavy for large checkpoints.
        self.generator = pipeline('text-generation', model=model_name)

    @staticmethod
    def _first_generated_text(output) -> str:
        """Return the generated string of the first candidate in one pipeline result."""
        # A batched (list-input) call yields a list of candidate dicts per
        # prompt, e.g. [{'generated_text': ...}]; tolerate a bare dict too.
        if isinstance(output, dict):
            return output['generated_text']
        return output[0]['generated_text']

    @requests
    def generate(self, docs: DocList[Prompt], **kwargs) -> DocList[Generation]:
        """Generate a completion for every prompt in ``docs``.

        :param docs: prompts to complete; each document's ``text`` is sent to
            the model.
        :param kwargs: additional request arguments supplied by Jina; unused.
        :return: one ``Generation`` per input prompt, in the same order,
            holding the prompt and the generated text.
        """
        generations = DocList[Generation]()
        prompts = docs.text
        if not prompts:
            # Nothing to do; avoid invoking the model with an empty batch.
            return generations
        llm_outputs = self.generator(prompts)
        # NOTE(review): per the transformers docs, ``generated_text`` includes
        # the prompt unless ``return_full_text=False`` is passed — confirm the
        # desired output format.
        for prompt, output in zip(prompts, llm_outputs):
            generations.append(
                Generation(prompt=prompt, text=self._first_generated_text(output))
            )
        return generations