-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathprompt.py
91 lines (77 loc) · 2.77 KB
/
prompt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class Prompt:
    """A single prompt built from a header, example shots, and a footer.

    The rendered text is produced by filling ``template`` with ``header``,
    the ``shots`` joined by blank lines, ``intermediate``, ``footer`` and an
    answer slot (see :meth:`get_prompt`).
    """

    system_prompt: Optional[str] = ""  # The system prompt
    role: Optional[str] = ""  # The role of the prompt
    header: str = ""  # Initial header
    intermediate: str = ""  # Intermediate text
    footer: str = ""  # Final footer
    target: str = ""  # Name of the target attribute
    original_point: Dict[str, Any] = field(default_factory=dict)  # Original point
    gt: Optional[str] = None  # Ground truth
    answer: Optional[str] = None  # Answer given to the prompt
    shots: List[str] = field(default_factory=list)  # Joined with "\n\n" when rendered
    id: int = -1
    template: str = "{header}\n{shots}\n{intermediate}\n\n{footer}\n\n{answer}"

    def get_prompt(self, show_answer=False):
        """Render the prompt text by filling ``template``.

        Args:
            show_answer: When truthy, the ``{answer}`` slot is filled with the
                ground truth ``gt``; otherwise it is filled with an empty
                string.

        Returns:
            The formatted prompt string.
        """
        return self.template.format(
            header=self.header,
            shots="\n\n".join(self.shots),
            intermediate=self.intermediate,
            footer=self.footer,
            # The slot is filled from ``gt``, not from ``answer``.
            answer=self.gt if show_answer else "",
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, dict) -> "Prompt":
        """Build an instance from a mapping of field names to values.

        The parameter name shadows the builtin ``dict``; it is kept unchanged
        so existing keyword callers continue to work.
        """
        return cls(**dict)

    def get_copy(self):
        """Return an independent copy of this prompt.

        Every field is carried over, including ``system_prompt`` and ``role``.
        The ``original_point`` and ``shots`` containers are shallow-copied so
        that adding or removing entries on the copy does not affect this
        instance.
        """
        import copy  # local import: keeps the block self-contained

        return Prompt(
            system_prompt=self.system_prompt,
            role=self.role,
            header=self.header,
            intermediate=self.intermediate,
            footer=self.footer,
            target=self.target,
            original_point=copy.copy(self.original_point),
            gt=self.gt,
            answer=self.answer,
            shots=copy.copy(self.shots),
            id=self.id,
            template=self.template,
        )
@dataclass
class Conversation:
    """An ordered list of prompts that share one conversation-level system prompt."""

    system_prompt: str
    prompts: List[Prompt]

    def __init__(self, system_prompt: str, prompts: List[Prompt]) -> None:
        """Store the system prompt and normalise the given prompts.

        Each prompt in ``prompts`` is modified in place: its own
        ``system_prompt`` is cleared and its ``template`` is reduced to
        ``"{intermediate}"``. The list object itself is stored as-is, not
        copied.
        """
        # @dataclass does not replace an __init__ defined in the class body,
        # so this constructor is the one that runs.
        self.system_prompt = system_prompt
        for prompt in prompts:
            prompt.system_prompt = ""
            prompt.template = "{intermediate}"
        self.prompts = prompts

    def set_system_prompt(self, system_prompt: str) -> "Conversation":
        """Replace the system prompt and return ``self`` to allow chaining."""
        self.system_prompt = system_prompt
        return self

    def get_copy(self):
        """Return an independent copy of this conversation.

        Each prompt is deep-copied into a new list, so appending prompts to
        the copy or swapping roles on it leaves this conversation untouched.
        The copies pass through the constructor and are therefore normalised
        the same way as any prompt handed to ``__init__``.
        """
        import copy  # local import: keeps the block self-contained

        return Conversation(
            system_prompt=self.system_prompt,
            prompts=[copy.deepcopy(prompt) for prompt in self.prompts],
        )

    def swap_roles(self, swap_dict: Dict[str, str]):
        """Rename prompt roles in place according to ``swap_dict``.

        Args:
            swap_dict: Maps a current role name to its replacement. A prompt
                whose role is not a key keeps its role and a notice is
                printed.
        """
        for prompt in self.prompts:
            if prompt.role in swap_dict:
                prompt.role = swap_dict[prompt.role]
            else:
                print("Role not found!")

    def add_prompt(self, prompt: Prompt):
        """Append ``prompt`` to the conversation.

        Unlike the constructor, this does not modify the prompt's
        ``system_prompt`` or ``template``.
        """
        self.prompts.append(prompt)