Skip to content

Commit

Permalink
add org
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders committed Sep 20, 2024
1 parent d6b3edb commit cc39788
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
58 changes: 58 additions & 0 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import Memory
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from memgpt.schemas.organization import Organization
from memgpt.schemas.source import Source
from memgpt.schemas.tool import Tool
from memgpt.schemas.user import User
Expand Down Expand Up @@ -134,6 +135,21 @@ def to_record(self) -> User:
return User(id=self.id, name=self.name, created_at=self.created_at)


class OrganizationModel(Base):
__tablename__ = "organizations"
__table_args__ = {"extend_existing": True}

id = Column(String, primary_key=True)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))

def __repr__(self) -> str:
return f"<Organization(id='{self.id}' name='{self.name}')>"

def to_record(self) -> Organization:
return Organization(id=self.id, name=self.name, created_at=self.created_at)


class APIKeyModel(Base):
"""Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""

Expand Down Expand Up @@ -515,6 +531,14 @@ def create_user(self, user: User):
session.add(UserModel(**vars(user)))
session.commit()

@enforce_types
def create_organization(self, organization: Organization):
with self.session_maker() as session:
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
raise ValueError(f"Organization with id {organization.id} already exists")
session.add(OrganizationModel(**vars(organization)))
session.commit()

@enforce_types
def create_block(self, block: Block):
with self.session_maker() as session:
Expand Down Expand Up @@ -638,6 +662,16 @@ def delete_user(self, user_id: str):

session.commit()

@enforce_types
def delete_organization(self, organization_id: str):
with self.session_maker() as session:
# delete from organizations table
session.query(OrganizationModel).filter(OrganizationModel.id == organization_id).delete()

# TODO: delete associated data

session.commit()

@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
Expand Down Expand Up @@ -685,6 +719,30 @@ def get_user(self, user_id: str) -> Optional[User]:
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
with self.session_maker() as session:
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
if cursor:
query = query.filter(OrganizationModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
organization_records = [r.to_record() for r in results]
next_cursor = organization_records[-1].id
assert isinstance(next_cursor, str)

return next_cursor, organization_records

@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
Expand Down
2 changes: 1 addition & 1 deletion memgpt/server/rest_api/routers/v1/organizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_all_orgs(
Get a list of all orgs in the database
"""
try:
next_cursor, orgs = server.ms.list_organization(cursor=cursor, limit=limit)
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit)
except HTTPException:
raise
except Exception as e:
Expand Down
16 changes: 16 additions & 0 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
Expand Down Expand Up @@ -694,12 +695,27 @@ def create_user(self, request: UserCreate) -> User:
logger.info(f"Created new user from config: {user}")

# add default for the user
# TODO: move to org
assert user.id is not None, f"User id is None: {user}"
self.add_default_blocks(user.id)
self.add_default_tools(module_name="base", user_id=user.id)

return user

def create_org(self, request: OrganizationCreate) -> Organization:
"""Create a new org using a config"""
if not request.name:
# auto-generate a name
request.name = create_random_username()
org = Organization(name=request.name)
self.ms.create_org(org)
logger.info(f"Created new org from config: {org}")

# add default for the org
# TODO: add default data

return org

def create_agent(
self,
request: CreateAgent,
Expand Down

0 comments on commit cc39788

Please sign in to comment.