-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Increasing WildBgDataSet speed (#10)
Moving away from pandas dataframe changing to loading data into tensor on init
- Loading branch information
Showing
3 changed files
with
114 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from torch.utils.data import Dataset | ||
import torch | ||
|
||
class WildBgDataSet(Dataset): | ||
def __init__(self, csv_files: list | str): | ||
# If you want to combine several CSV files from another folder, use the following: | ||
if isinstance(csv_files, str): | ||
csv_files = [csv_files] | ||
labels = [] | ||
inputs = [] | ||
for path in csv_files: | ||
with open(path, 'r') as f: | ||
lines = f.readlines() | ||
for line in lines[1:]: | ||
line = line.strip().split(';') | ||
line = list(map(float, line)) | ||
labels.append(line[:6]) | ||
inputs.append(line[6:]) | ||
self.inputs = torch.Tensor(inputs) | ||
self.labels = torch.Tensor(labels) | ||
|
||
def __len__(self): | ||
return self.inputs.shape[0] | ||
|
||
def __getitem__(self, idx): | ||
# First 6 columns are outputs, last 202 columns are inputs | ||
return self.inputs[idx], self.labels[idx] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from torch import nn | ||
|
||
class Model(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
# Inputs to hidden layer linear transformation | ||
self.hidden1 = nn.Linear(202, 300) | ||
self.hidden2 = nn.Linear(300, 250) | ||
self.hidden3 = nn.Linear(250, 200) | ||
|
||
# Output layer, 6 outputs for win/lose - normal/gammon/bg | ||
self.output = nn.Linear(200, 6) | ||
|
||
# Define activation function and softmax output | ||
self.activation = nn.ReLU() | ||
self.softmax = nn.Softmax(dim=1) | ||
|
||
def forward(self, x): | ||
# Pass the input tensor through each of our operations | ||
x = self.hidden1(x) | ||
x = self.activation(x) | ||
x = self.hidden2(x) | ||
x = self.activation(x) | ||
x = self.hidden3(x) | ||
x = self.activation(x) | ||
x = self.output(x) | ||
x = self.softmax(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters