inference.py
from typing import Dict, List

from datasets import Dataset
from vllm import LLM, SamplingParams

from examples.star.utils import cleanup, generate_prompt


def generate_predictions(
    model_name: str, dataset: Dataset, temperature: float = 1.0, n: int = 1
) -> List[List[str]]:
    """Generate predictions for a given dataset using a specified language model
    and sampling parameters. The function constructs a chat prompt from each
    example, samples ``n`` completions per prompt, and returns the generated
    texts grouped by example.

    Args:
    ----
        model_name (str): Name of the model to use for generation.
        dataset (Dataset): The Dataset object.
        temperature (float, optional): Temperature setting for the model's
            sampling strategy. Default is 1.0.
        n (int, optional): Number of sampling runs per prompt. Default is 1.

    Returns:
    -------
        predictions (List[List[str]]): Predictions on the dataset.
    """
    sampling_params = SamplingParams(n=n, temperature=temperature, max_tokens=512)
    llm = LLM(model=model_name)

    # Build one chat-style message list per dataset example.
    prompts: List[List[Dict]] = []
    for example in dataset:
        prompt = example["prompt"]
        test = example["test"]
        prompt = generate_prompt(prompt, test)
        prompts.append([{"role": "user", "content": prompt}])

    # Each output carries the n sampled completions for one prompt.
    outputs = llm.chat(prompts, sampling_params)
    results: List[List[str]] = []
    for output in outputs:
        generated_texts = [one.text for one in output.outputs]
        results.append(generated_texts)

    # Release the GPU memory held by the vLLM engine before returning.
    cleanup(llm, vllm=True)
    return results
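

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of calling generate_predictions, assuming a toy dataset
# whose rows carry the "prompt" and "test" fields the function expects.
# The model name below is an arbitrary small chat model chosen only for
# illustration; substitute whichever checkpoint you actually use.
if __name__ == "__main__":
    toy = Dataset.from_dict(
        {
            "prompt": ["def add(a, b):\n    ..."],
            "test": ["assert add(1, 2) == 3"],
        }
    )
    preds = generate_predictions(
        "Qwen/Qwen2.5-0.5B-Instruct", toy, temperature=0.8, n=2
    )
    # preds[i][j] is the j-th sampled completion for the i-th example.
    for completions in preds:
        for text in completions:
            print(text)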