Commit e91692c9 authored by khoipham's avatar khoipham

implement FilteredAdapter interface

parent d413a3f8
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"github.com/casbin/casbin/v2/persist" "github.com/casbin/casbin/v2/persist"
"github.com/go-pg/pg/v9" "github.com/go-pg/pg/v9"
"github.com/go-pg/pg/v9/orm" "github.com/go-pg/pg/v9/orm"
"golang.org/x/crypto/sha3" "github.com/mmcloughlin/meow"
) )
const ( const (
...@@ -27,9 +27,15 @@ type CasbinRule struct { ...@@ -27,9 +27,15 @@ type CasbinRule struct {
V5 string V5 string
} }
type Filter struct {
P []string
G []string
}
// Adapter represents the github.com/go-pg/pg adapter for policy storage. // Adapter represents the github.com/go-pg/pg adapter for policy storage.
type Adapter struct { type Adapter struct {
db *pg.DB db *pg.DB
filtered bool
} }
// NewAdapter is the constructor for Adapter. // NewAdapter is the constructor for Adapter.
...@@ -98,37 +104,37 @@ func (a *Adapter) createTable() error { ...@@ -98,37 +104,37 @@ func (a *Adapter) createTable() error {
return nil return nil
} }
func loadPolicyLine(line *CasbinRule, model model.Model) { func (r *CasbinRule) String() string {
const prefixLine = ", " const prefixLine = ", "
var sb strings.Builder var sb strings.Builder
sb.WriteString(line.PType) sb.WriteString(r.PType)
if len(line.V0) > 0 { if len(r.V0) > 0 {
sb.WriteString(prefixLine) sb.WriteString(prefixLine)
sb.WriteString(line.V0) sb.WriteString(r.V0)
} }
if len(line.V1) > 0 { if len(r.V1) > 0 {
sb.WriteString(prefixLine) sb.WriteString(prefixLine)
sb.WriteString(line.V1) sb.WriteString(r.V1)
} }
if len(line.V2) > 0 { if len(r.V2) > 0 {
sb.WriteString(prefixLine) sb.WriteString(prefixLine)
sb.WriteString(line.V2) sb.WriteString(r.V2)
} }
if len(line.V3) > 0 { if len(r.V3) > 0 {
sb.WriteString(prefixLine) sb.WriteString(prefixLine)
sb.WriteString(line.V3) sb.WriteString(r.V3)
} }
if len(line.V4) > 0 { if len(r.V4) > 0 {
sb.WriteString(prefixLine) sb.WriteString(prefixLine)
sb.WriteString(line.V4) sb.WriteString(r.V4)
} }
if len(line.V5) > 0 { if len(r.V5) > 0 {
sb.WriteString(prefixLine) sb.WriteString(prefixLine)
sb.WriteString(line.V5) sb.WriteString(r.V5)
} }
persist.LoadPolicyLine(sb.String(), model) return sb.String()
} }
// LoadPolicy loads policy from database. // LoadPolicy loads policy from database.
...@@ -140,16 +146,17 @@ func (a *Adapter) LoadPolicy(model model.Model) error { ...@@ -140,16 +146,17 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
} }
for _, line := range lines { for _, line := range lines {
loadPolicyLine(line, model) persist.LoadPolicyLine(line.String(), model)
} }
a.filtered = false
return nil return nil
} }
func policyID(ptype string, rule []string) string { func policyID(ptype string, rule []string) string {
data := strings.Join(append([]string{ptype}, rule...), ",") data := strings.Join(append([]string{ptype}, rule...), ",")
sum := make([]byte, 64) sum := meow.Checksum(0, []byte(data))
sha3.ShakeSum128(sum, []byte(data))
return fmt.Sprintf("%x", sum) return fmt.Sprintf("%x", sum)
} }
...@@ -248,3 +255,87 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, ...@@ -248,3 +255,87 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
_, err := query.Delete() _, err := query.Delete()
return err return err
} }
func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
if filter == nil {
return a.LoadPolicy(model)
}
filterValue, ok := filter.(*Filter)
if !ok {
return fmt.Errorf("invalid filter type")
}
err := a.loadFilteredPolicy(model, filterValue, persist.LoadPolicyLine)
if err != nil {
return err
}
a.filtered = true
return nil
}
func buildQuery(query *orm.Query, values []string) (*orm.Query, error) {
for ind, v := range values {
if v == "" {
continue
}
switch ind {
case 0:
query = query.Where("v0 = ?", v)
case 1:
query = query.Where("v1 = ?", v)
case 2:
query = query.Where("v2 = ?", v)
case 3:
query = query.Where("v3 = ?", v)
case 4:
query = query.Where("v4 = ?", v)
case 5:
query = query.Where("v5 = ?", v)
default:
return nil, fmt.Errorf("filter has more values than expected, should not exceed 6 values")
}
}
return query, nil
}
func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter, handler func(string, model.Model)) error {
if filter.P != nil {
lines := []*CasbinRule{}
query := a.db.Model(&lines).Where("p_type = 'p'")
query, err := buildQuery(query, filter.P)
if err != nil {
return err
}
err = query.Select()
if err != nil {
return err
}
for _, line := range lines {
handler(line.String(), model)
}
}
if filter.G != nil {
lines := []*CasbinRule{}
query := a.db.Model(&lines).Where("p_type = 'g'")
query, err := buildQuery(query, filter.G)
if err != nil {
return err
}
err = query.Select()
if err != nil {
return err
}
for _, line := range lines {
handler(line.String(), model)
}
}
return nil
}
func (a *Adapter) IsFiltered() bool {
return a.filtered
}
...@@ -17,10 +17,9 @@ type AdapterTestSuite struct { ...@@ -17,10 +17,9 @@ type AdapterTestSuite struct {
a *Adapter a *Adapter
} }
func (s *AdapterTestSuite) testGetPolicy(res [][]string) { func (s *AdapterTestSuite) assertPolicy(expected, res [][]string) {
s.T().Helper() s.T().Helper()
myRes := s.e.GetPolicy() s.Assert().True(util.Array2DEquals(expected, res), "Policy Got: %v, supposed to be %v", res, expected)
s.Assert().True(util.Array2DEquals(res, myRes), "Policy Got: %v, supposed to be %v", myRes, res)
} }
func (s *AdapterTestSuite) dropCasbinDB() { func (s *AdapterTestSuite) dropCasbinDB() {
...@@ -56,7 +55,8 @@ func (s *AdapterTestSuite) TearDownTest() { ...@@ -56,7 +55,8 @@ func (s *AdapterTestSuite) TearDownTest() {
} }
func (s *AdapterTestSuite) TestSaveLoad() { func (s *AdapterTestSuite) TestSaveLoad() {
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) s.Assert().False(s.e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
} }
func (s *AdapterTestSuite) TestAutoSave() { func (s *AdapterTestSuite) TestAutoSave() {
...@@ -72,7 +72,7 @@ func (s *AdapterTestSuite) TestAutoSave() { ...@@ -72,7 +72,7 @@ func (s *AdapterTestSuite) TestAutoSave() {
err = s.e.LoadPolicy() err = s.e.LoadPolicy()
s.Require().NoError(err) s.Require().NoError(err)
// This is still the original policy. // This is still the original policy.
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
// Now we enable the AutoSave. // Now we enable the AutoSave.
s.e.EnableAutoSave(true) s.e.EnableAutoSave(true)
...@@ -85,14 +85,14 @@ func (s *AdapterTestSuite) TestAutoSave() { ...@@ -85,14 +85,14 @@ func (s *AdapterTestSuite) TestAutoSave() {
err = s.e.LoadPolicy() err = s.e.LoadPolicy()
s.Require().NoError(err) s.Require().NoError(err)
// The policy has a new rule: {"alice", "data1", "write"}. // The policy has a new rule: {"alice", "data1", "write"}.
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}) s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}, s.e.GetPolicy())
// Aditional AddPolicy have no effect // Aditional AddPolicy have no effect
_, err = s.e.AddPolicy("alice", "data1", "write") _, err = s.e.AddPolicy("alice", "data1", "write")
s.Require().NoError(err) s.Require().NoError(err)
err = s.e.LoadPolicy() err = s.e.LoadPolicy()
s.Require().NoError(err) s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}) s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}, s.e.GetPolicy())
s.Require().NoError(err) s.Require().NoError(err)
} }
...@@ -107,31 +107,76 @@ func (s *AdapterTestSuite) TestConstructorOptions() { ...@@ -107,31 +107,76 @@ func (s *AdapterTestSuite) TestConstructorOptions() {
s.e, err = casbin.NewEnforcer("examples/rbac_model.conf", a) s.e, err = casbin.NewEnforcer("examples/rbac_model.conf", a)
s.Require().NoError(err) s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
} }
func (s *AdapterTestSuite) TestRemovePolicy() { func (s *AdapterTestSuite) TestRemovePolicy() {
_, err := s.e.RemovePolicy("alice", "data1", "read") _, err := s.e.RemovePolicy("alice", "data1", "read")
s.Require().NoError(err) s.Require().NoError(err)
s.testGetPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) s.assertPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
err = s.e.LoadPolicy() err = s.e.LoadPolicy()
s.Require().NoError(err) s.Require().NoError(err)
s.testGetPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) s.assertPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
} }
func (s *AdapterTestSuite) TestRemoveFilteredPolicy() { func (s *AdapterTestSuite) TestRemoveFilteredPolicy() {
_, err := s.e.RemoveFilteredPolicy(0, "", "data2") _, err := s.e.RemoveFilteredPolicy(0, "", "data2")
s.Require().NoError(err) s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}}) s.assertPolicy([][]string{{"alice", "data1", "read"}}, s.e.GetPolicy())
err = s.e.LoadPolicy() err = s.e.LoadPolicy()
s.Require().NoError(err) s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}}) s.assertPolicy([][]string{{"alice", "data1", "read"}}, s.e.GetPolicy())
}
func (s *AdapterTestSuite) TestLoadFilteredPolicy() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(&Filter{
P: []string{"", "", "read"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"data2_admin", "data2", "read"}}, e.GetPolicy())
}
func (s *AdapterTestSuite) TestLoadFilteredGroupingPolicy() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(&Filter{
G: []string{"bob"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{}, e.GetGroupingPolicy())
e, err = casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(&Filter{
G: []string{"alice"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data2_admin"}}, e.GetGroupingPolicy())
}
func (s *AdapterTestSuite) TestLoadFilteredPolicyNilFilter() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(nil)
s.Require().NoError(err)
s.Assert().False(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
} }
func TestAdapterTestSuite(t *testing.T) { func TestAdapterTestSuite(t *testing.T) {
......
...@@ -6,6 +6,7 @@ require ( ...@@ -6,6 +6,7 @@ require (
github.com/casbin/casbin v1.9.1 github.com/casbin/casbin v1.9.1
github.com/casbin/casbin/v2 v2.1.2 github.com/casbin/casbin/v2 v2.1.2
github.com/go-pg/pg/v9 v9.1.0 github.com/go-pg/pg/v9 v9.1.0
github.com/mmcloughlin/meow v0.0.0-20181112033425-871e50784daf
github.com/stretchr/testify v1.4.0 github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment