Skip to content

Commit

Permalink
finalized
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah-Wilderom committed May 18, 2024
1 parent 6437b27 commit 26990b0
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 112 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ test:
clean:
rm -rf 3000_network
rm -rf 4000_network
rm -rf 5000_network
77 changes: 35 additions & 42 deletions crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"io"
)

func generateId() string {
buf := make([]byte, 32)
io.ReadFull(rand.Reader, buf)
return hex.EncodeToString(buf)
}

func hashKey(key string) string {
hash := sha256.Sum256([]byte(key))
return hex.EncodeToString(hash[:])
}

func newEncryptionKey() []byte {
keyBuf := make([]byte, 32)
io.ReadFull(rand.Reader, keyBuf)
return keyBuf
}

func copyDecrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
block, err := aes.NewCipher(key)
if err != nil {
return 0, err
}

// Read the IV from the given io.Reader which,
// in our case should be the block.BlockSize() bytes we read.
iv := make([]byte, block.BlockSize())
if _, err := src.Read(iv); err != nil {
return 0, err
}

func copyStream(stream cipher.Stream, blockSize int, src io.Reader, dst io.Writer) (int, error) {
var (
buf = make([]byte, 32*1024)
stream = cipher.NewCTR(block, iv)
nw = block.BlockSize()
buf = make([]byte, 32*1024)
nw = blockSize
)

for {
Expand All @@ -56,6 +56,23 @@ func copyDecrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
return nw, nil
}

func copyDecrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
block, err := aes.NewCipher(key)
if err != nil {
return 0, err
}

// Read the IV from the given io.Reader which,
// in our case should be the block.BlockSize() bytes we read.
iv := make([]byte, block.BlockSize())
if _, err := src.Read(iv); err != nil {
return 0, err
}

stream := cipher.NewCTR(block, iv)
return copyStream(stream, block.BlockSize(), src, dst)
}

func copyEncrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
block, err := aes.NewCipher(key)
if err != nil {
Expand All @@ -72,30 +89,6 @@ func copyEncrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
return 0, err
}

var (
buf = make([]byte, 32*1024)
stream = cipher.NewCTR(block, iv)
nw = block.BlockSize()
)

for {
n, err := src.Read(buf)
if n > 0 {
stream.XORKeyStream(buf, buf[:n])
nn, err := dst.Write(buf[:n])
if err != nil {
return 0, err
}
nw += nn
}

if err == io.EOF {
break
}
if err != nil {
return 0, err
}
}

return nw, nil
stream := cipher.NewCTR(block, iv)
return copyStream(stream, block.BlockSize(), src, dst)
}
44 changes: 28 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"bytes"
"fmt"
"io"
"log"
"os"
"time"
Expand Down Expand Up @@ -44,29 +45,40 @@ func main() {

s1 := makeServer(fmt.Sprintf(":%s", os.Getenv("APP_PORT")), "")
s2 := makeServer(":4000", ":3000")
s3 := makeServer(":5000", ":3000", ":4000")

go func() {
log.Fatal(s1.Start())
}()
go func() {
log.Fatal(s2.Start())
}()

time.Sleep(2 * time.Second)

go s2.Start()
go s3.Start()
time.Sleep(2 * time.Second)

data := bytes.NewReader([]byte("my big data file here"))
s2.Store("myprivatedata", data)
time.Sleep(5 * time.Millisecond)

// r, err := s2.Get("myprivatedata")
// if err != nil {
// log.Fatal(err)
// }
//
// b, err := io.ReadAll(r)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Println(string(b))
for i := 0; i < 20; i++ {
key := fmt.Sprintf("picture_%d.png", i)
data := bytes.NewReader([]byte("my big data file here"))
s3.Store(key, data)
time.Sleep(5 * time.Millisecond)

if err := s3.store.Delete(s3.Id, key); err != nil {
log.Fatal(err)
}

r, err := s3.Get(key)
if err != nil {
log.Fatal(err)
}

b, err := io.ReadAll(r)
if err != nil {
log.Fatal(err)
}

fmt.Println(string(b))
}
}
58 changes: 29 additions & 29 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
)

type FileServerOpts struct {
Id string
EncKey []byte
StorageRoot string
PathTransformFunc PathTransformFunc
Expand All @@ -37,6 +38,10 @@ func NewFileServer(opts FileServerOpts) *FileServer {
PathTransformFunc: opts.PathTransformFunc,
}

if len(opts.Id) == 0 {
opts.Id = generateId()
}

return &FileServer{
FileServerOpts: opts,
store: NewStore(storeOpts),
Expand All @@ -45,16 +50,6 @@ func NewFileServer(opts FileServerOpts) *FileServer {
}
}

func (s *FileServer) stream(msg *Message) error {
peers := []io.Writer{}
for _, peer := range s.peers {
peers = append(peers, peer)
}

mw := io.MultiWriter(peers...)
return gob.NewEncoder(mw).Encode(msg)
}

func (s *FileServer) broadcast(msg *Message) error {
buf := new(bytes.Buffer)
if err := gob.NewEncoder(buf).Encode(msg); err != nil {
Expand All @@ -76,26 +71,29 @@ type Message struct {
}

type MessageStoreFile struct {
Id string
Key string
Size int64
}

type MessageGetFile struct {
Id string
Key string
}

func (s *FileServer) Get(key string) (io.Reader, error) {
if s.store.Has(key) {
if s.store.Has(s.Id, key) {
fmt.Printf("[%s] serving file (%s) from local disk\n", s.Transport.Addr(), key)
_, r, err := s.store.Read(key)
_, r, err := s.store.Read(s.Id, key)
return r, err
}

fmt.Printf("[%s] dont have file (%s) locally, fetching from network...\n", s.Transport.Addr(), key)

msg := Message{
Payload: MessageGetFile{
Key: key,
Id: s.Id,
Key: hashKey(key),
},
}

Expand All @@ -114,7 +112,7 @@ func (s *FileServer) Get(key string) (io.Reader, error) {
return nil, err
}

n, err := s.store.Write(key, io.LimitReader(peer, fileSize))
n, err := s.store.WriteDecrypt(s.Id, s.EncKey, key, io.LimitReader(peer, fileSize))
if err != nil {
return nil, err
}
Expand All @@ -124,7 +122,7 @@ func (s *FileServer) Get(key string) (io.Reader, error) {
peer.CloseStream()
}

_, r, err := s.store.Read(key)
_, r, err := s.store.Read(s.Id, key)
return r, err
}

Expand All @@ -134,14 +132,15 @@ func (s *FileServer) Store(key string, r io.Reader) error {
tee = io.TeeReader(r, fileBuffer)
)

size, err := s.store.Write(key, tee)
size, err := s.store.Write(s.Id, key, tee)
if err != nil {
return err
}

msg := Message{
Payload: MessageStoreFile{
Key: key,
Id: s.Id,
Key: hashKey(key),
Size: size + 16,
},
}
Expand All @@ -152,16 +151,17 @@ func (s *FileServer) Store(key string, r io.Reader) error {

time.Sleep(5 * time.Millisecond)

// TODO: use a multiwriter here.
peers := []io.Writer{}
for _, peer := range s.peers {
peer.Send([]byte{p2p.IncomingStream})
n, err := copyEncrypt(s.EncKey, fileBuffer, peer)
if err != nil {
return err
}

fmt.Println("received and written bytes to disk: ", n)
peers = append(peers, peer)
}
mw := io.MultiWriter(peers...)
mw.Write([]byte{p2p.IncomingStream})
n, err := copyEncrypt(s.EncKey, fileBuffer, mw)
if err != nil {
return err
}
fmt.Printf("[%s] received and written (%d) bytes to disk\n", s.Transport.Addr(), n)

return nil
}
Expand Down Expand Up @@ -216,13 +216,13 @@ func (s *FileServer) handleMessage(from string, msg *Message) error {
}

func (s *FileServer) handleMessageGetFile(from string, msg MessageGetFile) error {
if !s.store.Has(msg.Key) {
if !s.store.Has(msg.Id, msg.Key) {
return fmt.Errorf("[%s] need to serve file (%s) but it does not exists on disk", s.Transport.Addr(), msg.Key)
}

fmt.Printf("[%s] serving file (%s) over the network\n", s.Transport.Addr(), msg.Key)

fileSize, r, err := s.store.Read(msg.Key)
fileSize, r, err := s.store.Read(msg.Id, msg.Key)
if err != nil {
return err
}
Expand Down Expand Up @@ -258,7 +258,7 @@ func (s *FileServer) handleMessageStoreFile(from string, msg MessageStoreFile) e
return fmt.Errorf("peer (%s) could not be found in the peer list", from)
}

n, err := s.store.Write(msg.Key, io.LimitReader(peer, msg.Size))
n, err := s.store.Write(msg.Id, msg.Key, io.LimitReader(peer, msg.Size))
if err != nil {
return err
}
Expand All @@ -277,7 +277,7 @@ func (s *FileServer) bootstrapNetwork() error {
}

go func(addr string) {
log.Println("attempting to connect with remote:", addr)
log.Printf("[%s] attempting to connect with remote %s\n", s.Transport.Addr(), addr)
if err := s.Transport.Dial(addr); err != nil {
log.Println("dial error:", err)
}
Expand Down
Loading

0 comments on commit 26990b0

Please sign in to comment.