Skip to content

Commit

Permalink
Supercoder Issue 1358 Resolved (#1426)
Browse files Browse the repository at this point in the history
* Update user.py issue 1358

* Update auth.py issue 1358
  • Loading branch information
supercoder-dev authored Dec 5, 2024
1 parent 6c816d2 commit 24578a8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
21 changes: 14 additions & 7 deletions superagi/controllers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ class Config:

# CRUD Operations
@router.post("/add", response_model=UserOut, status_code=201)
def create_user(user: UserIn,
Authorize: AuthJWT = Depends(check_auth)):
def create_user(user: UserIn, Authorize: AuthJWT = Depends(check_auth)):
"""
Create a new user.
Expand All @@ -62,21 +61,29 @@ def create_user(user: UserIn,
Raises:
HTTPException (status_code=400): If there is an issue creating the user.
HTTPException (status_code=422): If required fields are missing or incorrectly formatted.
"""

logger.info("Received user data: %s", user)

# Validate incoming request data
if not user.name or not user.email or not user.password:
logger.error("Missing required fields: name, email, or password")
raise HTTPException(status_code=422, detail="Missing required fields: name, email, or password")

db_user = db.session.query(User).filter(User.email == user.email).first()
if db_user:
return db_user

db_user = User(name=user.name, email=user.email, password=user.password, organisation_id=user.organisation_id)
db.session.add(db_user)
db.session.commit()
db.session.flush()

organisation = Organisation.find_or_create_organisation(db.session, db_user)
Project.find_or_create_default_project(db.session, organisation.id)
logger.info("User created", db_user)

#adding local llm configuration
logger.info("User created: %s", db_user)
# Adding local LLM configuration
ModelsConfig.add_llm_config(db.session, organisation.id)

return db_user
Expand Down
18 changes: 13 additions & 5 deletions superagi/helper/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,28 @@ def get_user_organisation(Authorize: AuthJWT = Depends(check_auth)):
return organisation


def get_current_user(Authorize: AuthJWT = Depends(check_auth)):
def get_current_user(Authorize: AuthJWT = Depends(check_auth), request: Request = Depends()):
env = get_config("ENV", "DEV")

if env == "DEV":
email = "[email protected]"
else:
# Retrieve the email of the logged-in user from the JWT token payload
email = Authorize.get_jwt_subject()
# Check for HTTP basic auth headers
auth_header = request.headers.get('Authorization')
if auth_header and auth_header.startswith('Basic '):
import base64
auth_decoded = base64.b64decode(auth_header.split(' ')[1]).decode('utf-8')
username, password = auth_decoded.split(':')
# Assuming username is the email
email = username
else:
# Retrieve the email of the logged-in user from the JWT token payload
email = Authorize.get_jwt_subject()

# Query the User table to find the user by their email
user = db.session.query(User).filter(User.email == email).first()
return user


api_key_header = APIKeyHeader(name="X-API-Key")


Expand All @@ -83,4 +91,4 @@ def get_organisation_from_api_key(api_key: str = Security(api_key_header)) -> Or
)

organisation = db.session.query(Organisation).filter(Organisation.id == query_result.org_id).first()
return organisation
return organisation

0 comments on commit 24578a8

Please sign in to comment.