Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace Box by Arc<RwLock> && add get_model, get_mut_model for Model.… #44

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "casbin"
version = "0.1.2"
version = "0.1.3"
authors = ["Joey <[email protected]>"]
edition = "2018"
license = "Apache-2.0"
Expand Down
27 changes: 16 additions & 11 deletions src/adapter/file_adapter.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use crate::adapter::Adapter;
use crate::error::Error;
use crate::error::{Error, ModelError};
use crate::model::Model;
use crate::Result;

use std::fs::File;
use std::io::prelude::*;
use std::io::BufReader;
use std::io::{Error as IoError, ErrorKind};

use crate::Result;

pub struct FileAdapter {
pub file_path: String,
}
Expand Down Expand Up @@ -54,21 +53,27 @@ impl Adapter for FileAdapter {
}

let mut tmp = String::new();
let ast_map1 = m.model.get("p").unwrap();
let ast_map1 = m.get_model().get("p").ok_or_else(|| {
Error::ModelError(ModelError::P(
"Missing policy definition in conf file".to_owned(),
))
})?;
for (ptype, ast) in ast_map1 {
for rule in &ast.policy {
for rule in ast.get_policy() {
let s1 = format!("{}, {}\n", ptype, rule.join(","));
tmp += s1.as_str();
}
}

let ast_map2 = m.model.get("g").unwrap();
for (ptype, ast) in ast_map2 {
for rule in &ast.policy {
let s1 = format!("{}, {}\n", ptype, rule.join(","));
tmp += s1.as_str();
if let Some(ast_map2) = m.get_model().get("g") {
for (ptype, ast) in ast_map2 {
for rule in ast.get_policy() {
let s1 = format!("{}, {}\n", ptype, rule.join(","));
tmp += s1.as_str();
}
}
}

self.save_policy_file(tmp)?;
Ok(())
}
Expand Down Expand Up @@ -102,7 +107,7 @@ fn load_policy_line(line: String, m: &mut Model) {
let tokens: Vec<String> = line.split(',').map(|x| x.trim().to_string()).collect();
let key = tokens[0].clone();

if let Some(sec) = key.chars().nth(0).map(|x| x.to_string()) {
if let Some(sec) = key.chars().next().map(|x| x.to_string()) {
if let Some(t1) = m.model.get_mut(&sec) {
if let Some(t2) = t1.get_mut(&key) {
t2.policy.push(tokens[1..].to_vec());
Expand Down
21 changes: 12 additions & 9 deletions src/enforcer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::Result;

use rhai::{Any, Engine, RegisterFn, Scope};

use std::sync::{Arc, RwLock};

pub trait MatchFnClone: Fn(Vec<Box<dyn Any>>) -> bool {
fn clone_box(&self) -> Box<dyn MatchFnClone>;
}
Expand All @@ -27,19 +29,19 @@ impl Clone for Box<dyn MatchFnClone> {
}
}

pub fn generate_g_function(rm: Box<dyn RoleManager>) -> Box<dyn MatchFnClone> {
pub fn generate_g_function(rm: Arc<RwLock<dyn RoleManager>>) -> Box<dyn MatchFnClone> {
let cb = move |args: Vec<Box<dyn Any>>| -> bool {
let args = args
.into_iter()
.filter_map(|x| x.downcast_ref::<String>().map(|y| y.to_owned()))
.collect::<Vec<String>>();

if args.len() == 3 {
let mut rm = rm.clone();
rm.has_link(&args[0], &args[1], Some(&args[2]))
rm.write()
.unwrap()
.has_link(&args[0], &args[1], Some(&args[2]))
} else if args.len() == 2 {
let mut rm = rm.clone();
rm.has_link(&args[0], &args[1], None)
rm.write().unwrap().has_link(&args[0], &args[1], None)
} else {
unreachable!()
}
Expand All @@ -53,7 +55,7 @@ pub struct Enforcer<A: Adapter> {
pub(crate) adapter: A,
pub(crate) fm: FunctionMap,
pub(crate) eft: Box<dyn Effector>,
pub(crate) rm: Box<dyn RoleManager>,
pub(crate) rm: Arc<RwLock<dyn RoleManager>>,
pub(crate) auto_save: bool,
pub(crate) auto_build_role_links: bool,
}
Expand All @@ -64,7 +66,7 @@ impl<A: Adapter> Enforcer<A> {
let m = m;
let fm = FunctionMap::default();
let eft = Box::new(DefaultEffector::default());
let rm = Box::new(DefaultRoleManager::new(10));
let rm = Arc::new(RwLock::new(DefaultRoleManager::new(10)));

let mut e = Self {
model: m,
Expand All @@ -75,6 +77,7 @@ impl<A: Adapter> Enforcer<A> {
auto_save: true,
auto_build_role_links: true,
};

// TODO: check filtered adapter, match over a implementor?
e.load_policy().unwrap();
e
Expand Down Expand Up @@ -243,8 +246,8 @@ impl<A: Adapter> Enforcer<A> {
}

pub fn build_role_links(&mut self) -> Result<()> {
self.rm.clear();
self.model.build_role_links(&mut self.rm)?;
self.rm.write().unwrap().clear();
self.model.build_role_links(Arc::clone(&self.rm))?;
Ok(())
}

Expand Down
52 changes: 35 additions & 17 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::collections::HashMap;
use std::convert::AsRef;
use std::net::IpAddr;
use std::path::Path;
use std::sync::{Arc, RwLock};

fn escape_assertion(s: String) -> String {
let re = Regex::new(r#"(r|p)\."#).unwrap();
Expand All @@ -37,23 +38,31 @@ pub struct Assertion {
pub(crate) value: String,
pub(crate) tokens: Vec<String>,
pub(crate) policy: Vec<Vec<String>>,
pub(crate) rm: Box<dyn RoleManager>,
pub(crate) rm: Arc<RwLock<dyn RoleManager>>,
}

impl Assertion {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
impl Default for Assertion {
fn default() -> Self {
Assertion {
key: String::new(),
value: String::new(),
tokens: vec![],
policy: vec![],
rm: Box::new(DefaultRoleManager::new(0)),
rm: Arc::new(RwLock::new(DefaultRoleManager::new(0))),
}
}
}

impl Assertion {
pub fn get_policy(&self) -> &Vec<Vec<String>> {
&self.policy
}

#[allow(clippy::borrowed_box)]
pub fn build_role_links(&mut self, rm: &mut Box<dyn RoleManager>) -> Result<()> {
pub fn get_mut_policy(&mut self) -> &mut Vec<Vec<String>> {
&mut self.policy
}

pub fn build_role_links(&mut self, rm: Arc<RwLock<dyn RoleManager>>) -> Result<()> {
let count = self.value.chars().filter(|&c| c == '_').count();
for rule in &self.policy {
if count < 2 {
Expand All @@ -66,17 +75,19 @@ impl Assertion {
return Err(Error::PolicyError(PolicyError::UnmatchPolicyDefinition).into());
}
if count == 2 {
rm.add_link(&rule[0], &rule[1], None);
rm.write().unwrap().add_link(&rule[0], &rule[1], None);
} else if count == 3 {
rm.add_link(&rule[0], &rule[1], Some(&rule[2]));
rm.write()
.unwrap()
.add_link(&rule[0], &rule[1], Some(&rule[2]));
} else if count >= 4 {
return Err(Error::ModelError(ModelError::P(
"Multiple domains are not supported".to_owned(),
))
.into());
}
}
self.rm = rm.clone();
self.rm = Arc::clone(&rm);
// self.rm.print_roles();
Ok(())
}
Expand Down Expand Up @@ -119,7 +130,7 @@ impl Model {
}

pub fn add_def(&mut self, sec: &str, key: &str, value: &str) -> bool {
let mut ast = Assertion::new();
let mut ast = Assertion::default();
ast.key = key.to_owned();
ast.value = value.to_owned();

Expand Down Expand Up @@ -154,15 +165,23 @@ impl Model {
let mut i = 1;

loop {
if !self.load_assersion(cfg, sec, &format!("{}{}", sec, self.get_key_suffix(i)))? {
if !self.load_assertion(cfg, sec, &format!("{}{}", sec, self.get_key_suffix(i)))? {
break Ok(());
} else {
i += 1;
}
}
}

fn load_assersion(&mut self, cfg: &Config, sec: &str, key: &str) -> Result<bool> {
pub fn get_model(&self) -> &HashMap<String, AssertionMap> {
&self.model
}

pub fn get_mut_model(&mut self) -> &mut HashMap<String, AssertionMap> {
&mut self.model
}

fn load_assertion(&mut self, cfg: &Config, sec: &str, key: &str) -> Result<bool> {
let sec_name = match sec {
"r" => "request_definition",
"p" => "policy_definition",
Expand All @@ -189,11 +208,10 @@ impl Model {
}
}

#[allow(clippy::borrowed_box)]
pub fn build_role_links(&mut self, rm: &mut Box<dyn RoleManager>) -> Result<()> {
pub fn build_role_links(&mut self, rm: Arc<RwLock<dyn RoleManager>>) -> Result<()> {
if let Some(asts) = self.model.get_mut("g") {
for (_key, ast) in asts.iter_mut() {
ast.build_role_links(rm)?;
for ast in asts.values_mut() {
ast.build_role_links(Arc::clone(&rm))?;
}
}
Ok(())
Expand Down
19 changes: 11 additions & 8 deletions src/rbac/default_role_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@ use crate::Result;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

type MatchingFunc = fn(&str, &str) -> bool;

#[derive(Clone)]
pub struct DefaultRoleManager {
pub all_roles: Arc<RwLock<HashMap<String, Arc<RwLock<Role>>>>>,
pub max_hierarchy_level: usize,
pub has_pattern: bool,
pub matching_func: Option<MatchingFunc>,
all_roles: Arc<RwLock<HashMap<String, Arc<RwLock<Role>>>>>,
max_hierarchy_level: usize,
}

impl Default for DefaultRoleManager {
fn default() -> Self {
DefaultRoleManager {
all_roles: Arc::new(RwLock::new(HashMap::new())),
max_hierarchy_level: 0,
}
}
}

impl DefaultRoleManager {
pub fn new(max_hierarchy_level: usize) -> Self {
DefaultRoleManager {
all_roles: Arc::new(RwLock::new(HashMap::new())),
max_hierarchy_level,
has_pattern: false,
matching_func: None,
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/rbac_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<A: Adapter> RbacApi for Enforcer<A> {
let mut roles = vec![];
if let Some(t1) = self.model.model.get_mut("g") {
if let Some(t2) = t1.get_mut("g") {
roles = t2.rm.get_roles(name, None);
roles = t2.rm.write().unwrap().get_roles(name, None);
}
}

Expand All @@ -62,7 +62,7 @@ impl<A: Adapter> RbacApi for Enforcer<A> {
fn get_users_for_role(&self, name: &str) -> Vec<String> {
if let Some(t1) = self.model.model.get("g") {
if let Some(t2) = t1.get("g") {
return t2.rm.get_users(name, None);
return t2.rm.read().unwrap().get_users(name, None);
}
}
return vec![];
Expand Down Expand Up @@ -125,7 +125,7 @@ impl<A: Adapter> RbacApi for Enforcer<A> {
let name1 = q[0].clone();
name = &name1;
q = q[1..].to_vec();
let roles = self.rm.get_roles(name, domain);
let roles = self.rm.write().unwrap().get_roles(name, domain);
for r in roles.iter().cloned() {
if !role_set.contains_key(&r) {
q.push(r.clone());
Expand Down