Skip to content

Commit

Permalink
[manager] fix address when binding to 0 (#67)
Browse files: browse the repository at this point in the history
  • Loading branch information
d4l3k authored Jan 10, 2025
1 parent b617bd2 commit 6b3665a
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl Manager {
py: Python<'_>,
replica_id: String,
lighthouse_addr: String,
address: String,
hostname: String,
bind: String,
store_addr: String,
world_size: u64,
Expand All @@ -52,7 +52,7 @@ impl Manager {
.block_on(manager::Manager::new(
replica_id,
lighthouse_addr,
address,
hostname,
bind,
store_addr,
world_size,
Expand Down
24 changes: 10 additions & 14 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use gethostname::gethostname;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
Expand Down Expand Up @@ -53,7 +52,7 @@ struct ManagerState {
pub struct Manager {
replica_id: String,
lighthouse_addr: String,
address: String,
hostname: String,
store_address: String,
world_size: u64,
state: Mutex<ManagerState>,
Expand All @@ -80,19 +79,20 @@ impl Manager {
pub async fn new(
replica_id: String,
lighthouse_addr: String,
address: String,
hostname: String,
bind: String,
store_addr: String,
world_size: u64,
) -> Result<Arc<Self>> {
let listener = tokio::net::TcpListener::bind(&bind).await?;
let local_addr = listener.local_addr()?;

let (should_commit_tx, _) = broadcast::channel(16);

Ok(Arc::new(Self {
replica_id: replica_id,
lighthouse_addr: lighthouse_addr,
address: address,
hostname: hostname,
store_address: store_addr,
world_size: world_size,
state: Mutex::new(ManagerState {
Expand All @@ -103,7 +103,7 @@ impl Manager {
should_commit_count: HashSet::new(),
should_commit_failures: HashSet::new(),
}),
local_addr: listener.local_addr()?,
local_addr: local_addr,
listener: Mutex::new(Some(listener)),
}))
}
Expand All @@ -122,11 +122,7 @@ impl Manager {
}

pub fn address(&self) -> String {
format!(
"http://{}:{}",
gethostname().into_string().unwrap(),
self.local_addr.port()
)
format!("http://{}:{}", self.hostname, self.local_addr.port())
}

async fn _run_grpc(self: Arc<Self>) -> Result<()> {
Expand Down Expand Up @@ -228,7 +224,7 @@ impl ManagerService for Arc<Manager> {
room_id: room_id.clone(),
requester: Some(QuorumMember {
replica_id: self.replica_id.clone(),
address: self.address.clone(),
address: self.address(),
store_address: self.store_address.clone(),
step: req.step,
world_size: self.world_size,
Expand Down Expand Up @@ -470,7 +466,7 @@ mod tests {
let manager = Manager::new(
"rep_id".to_string(),
lighthouse.address(),
"addr".to_string(),
"localhost".to_string(),
"[::]:0".to_string(),
"store_addr".to_string(),
1, // world size
Expand All @@ -493,7 +489,7 @@ mod tests {
lighthouse_fut.abort();

assert_eq!(resp.quorum_id, 1);
assert_eq!(resp.address, "addr".to_string());
assert_eq!(resp.address, manager.address());
assert_eq!(resp.store_address, "store_addr".to_string());
assert_eq!(resp.max_step, 123);
assert_eq!(resp.max_rank, Some(0));
Expand Down Expand Up @@ -525,7 +521,7 @@ mod tests {
let manager = Manager::new(
format!("rep_{}", replica_id),
lighthouse_addr,
"addr".to_string(),
"localhost".to_string(),
"[::]:0".to_string(),
"store_addr".to_string(),
1, // world size
Expand Down
7 changes: 3 additions & 4 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
lighthouse_addr: Optional[str] = None,
replica_id: Optional[str] = None,
port: Optional[int] = None,
hostname: str = socket.gethostname(),
) -> None:
"""
Args:
Expand All @@ -122,6 +123,7 @@ def __init__(
store_port: TCPStore port for this replica group
lighthouse_addr: if rank==0, the address of the lighthouse server
replica_id: if rank==0, the replica_id for this group
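hostname: if rank==0, the hostname to advertise to the lighthouse server; it is combined with the port the manager actually bound to, so binding to port 0 still advertises the real port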
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
Expand Down Expand Up @@ -159,12 +161,9 @@ def _manager_state_dict() -> Dict[str, T]:
self._manager: Optional[_Manager] = None

if rank == 0:
hostname = socket.gethostname()

if port is None:
port = int(os.environ.get(MANAGER_PORT_ENV, 0))

addr = f"http://{hostname}:{port}"
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

Expand All @@ -174,7 +173,7 @@ def _manager_state_dict() -> Dict[str, T]:
self._manager = _Manager(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
address=addr,
hostname=hostname,
bind=bind,
store_addr=f"{store_addr}:{store_port}",
world_size=world_size,
Expand Down
10 changes: 8 additions & 2 deletions torchft/torchft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Manager:
self,
replica_id: str,
lighthouse_addr: str,
address: str,
hostname: str,
bind: str,
store_addr: str,
world_size: int,
Expand All @@ -36,6 +36,12 @@ class Manager:
def shutdown(self) -> None: ...

class Lighthouse:
def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None) -> None: ...
def __init__(
self,
bind: str,
min_replicas: int,
join_timeout_ms: Optional[int] = None,
quorum_tick_ms: Optional[int] = None,
) -> None: ...
def address(self) -> str: ...
def shutdown(self) -> None: ...

0 comments on commit 6b3665a

Please sign in to comment.