Commit e91692c9 authored by khoipham's avatar khoipham

implement FilteredAdapter interface

parent d413a3f8
......@@ -8,7 +8,7 @@ import (
"github.com/casbin/casbin/v2/persist"
"github.com/go-pg/pg/v9"
"github.com/go-pg/pg/v9/orm"
"golang.org/x/crypto/sha3"
"github.com/mmcloughlin/meow"
)
const (
......@@ -27,9 +27,15 @@ type CasbinRule struct {
V5 string
}
type Filter struct {
P []string
G []string
}
// Adapter represents the github.com/go-pg/pg adapter for policy storage.
type Adapter struct {
db *pg.DB
filtered bool
}
// NewAdapter is the constructor for Adapter.
......@@ -98,37 +104,37 @@ func (a *Adapter) createTable() error {
return nil
}
func loadPolicyLine(line *CasbinRule, model model.Model) {
func (r *CasbinRule) String() string {
const prefixLine = ", "
var sb strings.Builder
sb.WriteString(line.PType)
if len(line.V0) > 0 {
sb.WriteString(r.PType)
if len(r.V0) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V0)
sb.WriteString(r.V0)
}
if len(line.V1) > 0 {
if len(r.V1) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V1)
sb.WriteString(r.V1)
}
if len(line.V2) > 0 {
if len(r.V2) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V2)
sb.WriteString(r.V2)
}
if len(line.V3) > 0 {
if len(r.V3) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V3)
sb.WriteString(r.V3)
}
if len(line.V4) > 0 {
if len(r.V4) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V4)
sb.WriteString(r.V4)
}
if len(line.V5) > 0 {
if len(r.V5) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V5)
sb.WriteString(r.V5)
}
persist.LoadPolicyLine(sb.String(), model)
return sb.String()
}
// LoadPolicy loads policy from database.
......@@ -140,16 +146,17 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
}
for _, line := range lines {
loadPolicyLine(line, model)
persist.LoadPolicyLine(line.String(), model)
}
a.filtered = false
return nil
}
func policyID(ptype string, rule []string) string {
data := strings.Join(append([]string{ptype}, rule...), ",")
sum := make([]byte, 64)
sha3.ShakeSum128(sum, []byte(data))
sum := meow.Checksum(0, []byte(data))
return fmt.Sprintf("%x", sum)
}
......@@ -248,3 +255,87 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
_, err := query.Delete()
return err
}
func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
if filter == nil {
return a.LoadPolicy(model)
}
filterValue, ok := filter.(*Filter)
if !ok {
return fmt.Errorf("invalid filter type")
}
err := a.loadFilteredPolicy(model, filterValue, persist.LoadPolicyLine)
if err != nil {
return err
}
a.filtered = true
return nil
}
func buildQuery(query *orm.Query, values []string) (*orm.Query, error) {
for ind, v := range values {
if v == "" {
continue
}
switch ind {
case 0:
query = query.Where("v0 = ?", v)
case 1:
query = query.Where("v1 = ?", v)
case 2:
query = query.Where("v2 = ?", v)
case 3:
query = query.Where("v3 = ?", v)
case 4:
query = query.Where("v4 = ?", v)
case 5:
query = query.Where("v5 = ?", v)
default:
return nil, fmt.Errorf("filter has more values than expected, should not exceed 6 values")
}
}
return query, nil
}
func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter, handler func(string, model.Model)) error {
if filter.P != nil {
lines := []*CasbinRule{}
query := a.db.Model(&lines).Where("p_type = 'p'")
query, err := buildQuery(query, filter.P)
if err != nil {
return err
}
err = query.Select()
if err != nil {
return err
}
for _, line := range lines {
handler(line.String(), model)
}
}
if filter.G != nil {
lines := []*CasbinRule{}
query := a.db.Model(&lines).Where("p_type = 'g'")
query, err := buildQuery(query, filter.G)
if err != nil {
return err
}
err = query.Select()
if err != nil {
return err
}
for _, line := range lines {
handler(line.String(), model)
}
}
return nil
}
func (a *Adapter) IsFiltered() bool {
return a.filtered
}
......@@ -17,10 +17,9 @@ type AdapterTestSuite struct {
a *Adapter
}
func (s *AdapterTestSuite) testGetPolicy(res [][]string) {
func (s *AdapterTestSuite) assertPolicy(expected, res [][]string) {
s.T().Helper()
myRes := s.e.GetPolicy()
s.Assert().True(util.Array2DEquals(res, myRes), "Policy Got: %v, supposed to be %v", myRes, res)
s.Assert().True(util.Array2DEquals(expected, res), "Policy Got: %v, supposed to be %v", res, expected)
}
func (s *AdapterTestSuite) dropCasbinDB() {
......@@ -56,7 +55,8 @@ func (s *AdapterTestSuite) TearDownTest() {
}
func (s *AdapterTestSuite) TestSaveLoad() {
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.Assert().False(s.e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}
func (s *AdapterTestSuite) TestAutoSave() {
......@@ -72,7 +72,7 @@ func (s *AdapterTestSuite) TestAutoSave() {
err = s.e.LoadPolicy()
s.Require().NoError(err)
// This is still the original policy.
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
// Now we enable the AutoSave.
s.e.EnableAutoSave(true)
......@@ -85,14 +85,14 @@ func (s *AdapterTestSuite) TestAutoSave() {
err = s.e.LoadPolicy()
s.Require().NoError(err)
// The policy has a new rule: {"alice", "data1", "write"}.
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}, s.e.GetPolicy())
// Aditional AddPolicy have no effect
_, err = s.e.AddPolicy("alice", "data1", "write")
s.Require().NoError(err)
err = s.e.LoadPolicy()
s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}, s.e.GetPolicy())
s.Require().NoError(err)
}
......@@ -107,31 +107,76 @@ func (s *AdapterTestSuite) TestConstructorOptions() {
s.e, err = casbin.NewEnforcer("examples/rbac_model.conf", a)
s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}
func (s *AdapterTestSuite) TestRemovePolicy() {
_, err := s.e.RemovePolicy("alice", "data1", "read")
s.Require().NoError(err)
s.testGetPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
err = s.e.LoadPolicy()
s.Require().NoError(err)
s.testGetPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}
func (s *AdapterTestSuite) TestRemoveFilteredPolicy() {
_, err := s.e.RemoveFilteredPolicy(0, "", "data2")
s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}}, s.e.GetPolicy())
err = s.e.LoadPolicy()
s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}}, s.e.GetPolicy())
}
func (s *AdapterTestSuite) TestLoadFilteredPolicy() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(&Filter{
P: []string{"", "", "read"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"data2_admin", "data2", "read"}}, e.GetPolicy())
}
func (s *AdapterTestSuite) TestLoadFilteredGroupingPolicy() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(&Filter{
G: []string{"bob"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{}, e.GetGroupingPolicy())
e, err = casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(&Filter{
G: []string{"alice"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data2_admin"}}, e.GetGroupingPolicy())
}
func (s *AdapterTestSuite) TestLoadFilteredPolicyNilFilter() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)
err = e.LoadFilteredPolicy(nil)
s.Require().NoError(err)
s.Assert().False(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}
func TestAdapterTestSuite(t *testing.T) {
......
......@@ -6,6 +6,7 @@ require (
github.com/casbin/casbin v1.9.1
github.com/casbin/casbin/v2 v2.1.2
github.com/go-pg/pg/v9 v9.1.0
github.com/mmcloughlin/meow v0.0.0-20181112033425-871e50784daf
github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment