Model_configuration.py
from dataclasses import dataclass
from typing import Optional
import logging

import torch
from cpufeature import CPUFeature
from petals.constants import PUBLIC_INITIAL_PEERS
# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
@dataclass
class ModelInfo:
    repo: str
    adapter: Optional[str] = None
MODELS = [
    ModelInfo(repo="meta-llama/Llama-2-70b-chat-hf"),
    ModelInfo(repo="stabilityai/StableBeluga2"),
    ModelInfo(repo="enoch/llama-65b-hf"),
    ModelInfo(repo="enoch/llama-65b-hf", adapter="timdettmers/guanaco-65b"),
    ModelInfo(repo="bigscience/bloomz"),
    # Add more models here
    ModelInfo(repo="roda-1"),
    ModelInfo(repo="kubu-hai.model.h5-2", adapter="kubu-hai.model.mat-2"),
]
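
# Minimal helper sketch (an addition, not part of the original config): resolves
# a ModelInfo entry by repo name so callers don't have to scan MODELS themselves.
def get_model_info(repo: str) -> Optional[ModelInfo]:
    """Return the first ModelInfo whose repo matches, or None if it is absent."""
    for info in MODELS:
        if info.repo == repo:
            return info
    return None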
DEFAULT_MODEL_NAME = "enoch/llama-65b-hf"
INITIAL_PEERS = PUBLIC_INITIAL_PEERS
# Set this to a list of multiaddrs to connect to a private swarm instead of the public one, for example:
# INITIAL_PEERS = ['/ip4/10.1.2.3/tcp/31234/p2p/QmcXhze98AcgGQDDYna23s4Jho96n8wkwLJv78vxtFNq44']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE == "cuda":
    TORCH_DTYPE = "auto"
elif CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]:
    TORCH_DTYPE = torch.bfloat16
else:
    TORCH_DTYPE = torch.float32  # You can use bfloat16 in this case too, but it will be slow
STEP_TIMEOUT = 10 * 60  # Timeout per generation step, in seconds (10 minutes)
MAX_SESSIONS = 50 # Has effect only for API v1 (HTTP-based)
logger.info("Configuration setup complete.")
# Example preprocess and postprocess functions
def preprocess(data):
    logger.debug("Preprocessing data")
    # Add your preprocessing steps here
    return data


def postprocess(data):
    logger.debug("Postprocessing data")
    # Add your postprocessing steps here
    return data
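
# Illustrative sketch (an assumption, not from the original file): a concrete
# preprocessing step could standardize a float tensor feature-wise like this.
def standardize(data: torch.Tensor) -> torch.Tensor:
    mean = data.mean(dim=0, keepdim=True)
    std = data.std(dim=0, keepdim=True)
    return (data - mean) / (std + 1e-8)  # small epsilon avoids division by zero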
# Example model class
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Define your model layers here

    def forward(self, x):
        # Define the forward pass; this stub returns its input unchanged
        return x
# Initialize model
model = MyModel().to(DEVICE)
# Example hybrid function
def hybrid_function(data):
    # Preprocessing on CPU
    data_cpu = data.to("cpu")
    preprocessed_data = preprocess(data_cpu)

    # Move data to GPU for inference if available
    preprocessed_data = preprocessed_data.to(DEVICE)
    output = model(preprocessed_data)

    # Postprocessing on CPU
    output_cpu = output.to("cpu")
    result = postprocess(output_cpu)
    return result

# Example usage (guarded so importing this module does not run the demo)
if __name__ == "__main__":
    data = torch.randn(100, 10).to(DEVICE)  # Example data
    result = hybrid_function(data)
    logger.info("Processing complete.")