Skip to content

Commit

Permalink
Merge branch 'development' into feature/add-vulns-audit
Browse files Browse the repository at this point in the history
  • Loading branch information
caverav committed Oct 14, 2024
2 parents 554c087 + 49fa2f5 commit 5830582
Show file tree
Hide file tree
Showing 29 changed files with 1,634 additions and 119 deletions.
44 changes: 44 additions & 0 deletions cwe_api/inferencer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from transformers import TextClassificationPipeline
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from torch import nn
import json
import numpy as np

ID2LABEL_PATH = './id2label.json'
LABEL2ID_PATH = './label2id.json'
MODEL_PATH = "./modelo_cwe/checkpoint-141693"
NUMBER_OF_PREDICTIONS = 3

class BestCweClassifications(TextClassificationPipeline):
def postprocess(self, model_outputs):
best_class = model_outputs["logits"]
return best_class

def inferencer(vuln):

with open(ID2LABEL_PATH) as f:
id2label = json.load(f)

with open(LABEL2ID_PATH) as f:
label2id = json.load(f)

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-multilingual-cased")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH,
num_labels=len(label2id),
id2label=id2label,
label2id=label2id)

m = nn.Softmax(dim=1)

pipe = BestCweClassifications(model=model, tokenizer=tokenizer)
output = pipe(vuln, batch_size=2, truncation="only_first")

softmax_output = m(output[0])[0]
ind = np.argpartition(softmax_output, -NUMBER_OF_PREDICTIONS)[-NUMBER_OF_PREDICTIONS:]

reversed_indices = np.flip(ind.numpy(),0).copy()
score = softmax_output[reversed_indices]

return [{'priority': i, 'label': id2label[str(reversed_indices[i])], 'score': float(score[i].numpy())} for i in range(0, 3)]

7 changes: 3 additions & 4 deletions cwe_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from inferencer import inferencer

app = FastAPI()

Expand All @@ -14,15 +15,13 @@
allow_headers=["*"], # Cabeceras permitidas
)

classifier = pipeline(task='text-classification', model="modelo_cwe/checkpoint-20790")

class VulnerabilityRequest(BaseModel):
vuln: str

@app.post("/classify")
async def classify_vulnerability(vuln_request: VulnerabilityRequest):
vuln = vuln_request.vuln
result = classifier(vuln)
result = inferencer(vuln)
return {"result": result}

@app.get("/")
Expand All @@ -31,7 +30,7 @@ async def read_root():
"Los dispositivos de CPU Siemens SIMATIC S7-300 permiten a los atacantes remotos causar una denegación de servicio "
"(transición de modo de defecto) a través de paquetes elaborados en (1) puerto TCP 102 o (2) Profibus."
)
result = classifier(example_vuln)
result = inferencer(example_vuln)
return {"example_vuln": example_vuln, "result": result}

if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@
"@radix-ui/react-switch": "^1.1.0",
"@types/react-router-dom": "^5.3.3",
"ae-cvss-calculator": "^1.0.2",
"chart.js": "^4.4.4",
"chartjs-plugin-annotation": "^3.0.1",
"check-password-strength": "^2.0.10",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"dayjs": "^1.11.7",
"html2canvas": "^1.4.1",
"i18next": "^23.12.2",
"i18next-browser-languagedetector": "^8.0.0",
"jspdf": "^2.5.2",
"lucide-react": "^0.435.0",
"react": "^18.3.1",
"react-chartjs-2": "^5.2.0",
"react-dom": "^18.3.1",
"react-i18next": "^15.0.0",
"react-quill-new": "^3.3.0",
Expand Down
161 changes: 161 additions & 0 deletions frontend/src/components/dashboard/AverageCVSS.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import { Cvss3P1 } from 'ae-cvss-calculator';
import {
BarElement,
CategoryScale,
Chart as ChartJS,
ChartOptions,
Legend,
LinearScale,
Title,
Tooltip,
} from 'chart.js';
import annotationPlugin from 'chartjs-plugin-annotation';
import React, { useEffect, useState } from 'react';
import { Bar } from 'react-chartjs-2';
import { useParams } from 'react-router-dom';

import { getAuditById } from '@/services/audits';

ChartJS.register(
CategoryScale,
LinearScale,
BarElement,
Title,
Tooltip,
Legend,
annotationPlugin,
);

const cvssStringToScore = (cvssScore: string) => {
try {
const cvssVector = new Cvss3P1(cvssScore);
return cvssVector.calculateExactOverallScore();
} catch (error) {
console.error('Invalid CVSS vector:', error);
}
return 0;
};

type AverageCVSSProps = {
auditId?: string;
};

const AverageCVSS: React.FC<AverageCVSSProps> = ({ auditId }) => {
const paramId = useParams().auditId;
if (!auditId) {
auditId = paramId;
}
const [averageCVSS, setAverageCVSS] = useState(0);
const [data, setData] = useState({
labels: [''],
datasets: [
{
data: [0],
backgroundColor: '#3498db',
},
],
});
useEffect(() => {
if (auditId === undefined) {
auditId = paramId;
}
getAuditById(auditId)
.then(audit => {
setAverageCVSS(
Math.round(
(audit.datas.findings.reduce(
(acc, finding) => acc + cvssStringToScore(finding.cvssv3),
0,
) /
audit.datas.findings.length) *
10,
) / 10,
);
setData({
labels: audit.datas.findings.map(finding => finding.title),
datasets: [
{
data: audit.datas.findings.map(finding =>
cvssStringToScore(finding.cvssv3),
),
// eslint-disable-next-line @typescript-eslint/consistent-type-assertions
backgroundColor: audit.datas.findings.map(finding =>
cvssStringToScore(finding.cvssv3) >= 9
? '#FF4136'
: cvssStringToScore(finding.cvssv3) >= 7
? '#FF851B'
: cvssStringToScore(finding.cvssv3) >= 4
? '#FFDC00'
: '#2ECC40',
) as unknown as string,
},
],
});
})
.catch(console.error);
}, [auditId, averageCVSS]);

const options: ChartOptions<'bar'> = {
indexAxis: 'y',
responsive: true,
maintainAspectRatio: false,
layout: {
padding: {
top: 30,
},
},
scales: {
x: {
beginAtZero: true,
max: 10,
ticks: {
stepSize: 2,
color: 'white',
},
grid: {
color: 'rgba(255, 255, 255, 0.1)',
},
},
y: {
ticks: {
color: 'white',
},
grid: {
display: false,
},
},
},
plugins: {
legend: {
display: false,
},
annotation: {
annotations: {
line1: {
type: 'line',
xMin: averageCVSS,
xMax: averageCVSS,
borderColor: '#2ecc71',
borderWidth: 2,
borderDash: [5, 5],
},
},
},
},
};

return (
<div className="bg-gray-800 rounded-lg p-6">
<div className="relative">
<div className="absolute top-0 left-0 w-full text-right pr-4 text-green-400 text-sm">
Average CVSS: {averageCVSS}
</div>
<div className="chart-container" style={{ height: '400px' }}>
<Bar data={data} options={options} />
</div>
</div>
</div>
);
};

export default AverageCVSS;
Loading

0 comments on commit 5830582

Please sign in to comment.