Unverified Commit d33f9714 authored by Zixuan Liu's avatar Zixuan Liu Committed by GitHub

Merge pull request #8 from troyanov/feature/custom-table-name-support

Feature/custom table name support
parents d951292b 2eea4794
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"github.com/casbin/casbin/v2/persist" "github.com/casbin/casbin/v2/persist"
"github.com/go-pg/pg/v9" "github.com/go-pg/pg/v9"
"github.com/go-pg/pg/v9/orm" "github.com/go-pg/pg/v9/orm"
"github.com/go-pg/pg/v9/types"
"github.com/mmcloughlin/meow" "github.com/mmcloughlin/meow"
) )
...@@ -30,10 +31,13 @@ type Filter struct { ...@@ -30,10 +31,13 @@ type Filter struct {
// Adapter represents the github.com/go-pg/pg adapter for policy storage. // Adapter represents the github.com/go-pg/pg adapter for policy storage.
type Adapter struct { type Adapter struct {
db *pg.DB db *pg.DB
filtered bool tableName string
filtered bool
} }
type Option func(a *Adapter)
// NewAdapter is the constructor for Adapter. // NewAdapter is the constructor for Adapter.
// arg should be a PostgreS URL string or of type *pg.Options // arg should be a PostgreS URL string or of type *pg.Options
// The adapter will create a DB named "casbin" if it doesn't exist // The adapter will create a DB named "casbin" if it doesn't exist
...@@ -54,14 +58,31 @@ func NewAdapter(arg interface{}) (*Adapter, error) { ...@@ -54,14 +58,31 @@ func NewAdapter(arg interface{}) (*Adapter, error) {
// NewAdapterByDB creates new Adapter by using existing DB connection // NewAdapterByDB creates new Adapter by using existing DB connection
// creates table from CasbinRule struct if it doesn't exist // creates table from CasbinRule struct if it doesn't exist
func NewAdapterByDB(db *pg.DB) (*Adapter, error) { func NewAdapterByDB(db *pg.DB, opts ...Option) (*Adapter, error) {
a := &Adapter{db: db} a := &Adapter{db: db}
for _, opt := range opts {
opt(a)
}
if len(a.tableName) > 0 {
a.db.Model((*CasbinRule)(nil)).TableModel().Table().Name = a.tableName
a.db.Model((*CasbinRule)(nil)).TableModel().Table().FullName = (types.Safe)(a.tableName)
a.db.Model((*CasbinRule)(nil)).TableModel().Table().FullNameForSelects = (types.Safe)(a.tableName)
}
if err := a.createTableifNotExists(); err != nil { if err := a.createTableifNotExists(); err != nil {
return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err) return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
} }
return a, nil return a, nil
} }
// WithTableName can be used to pass custom table name for Casbin rules
func WithTableName(tableName string) Option {
return func(a *Adapter) {
a.tableName = tableName
}
}
func createCasbinDatabase(arg interface{}) (*pg.DB, error) { func createCasbinDatabase(arg interface{}) (*pg.DB, error) {
var opts *pg.Options var opts *pg.Options
var err error var err error
...@@ -145,7 +166,7 @@ func (r *CasbinRule) String() string { ...@@ -145,7 +166,7 @@ func (r *CasbinRule) String() string {
func (a *Adapter) LoadPolicy(model model.Model) error { func (a *Adapter) LoadPolicy(model model.Model) error {
var lines []*CasbinRule var lines []*CasbinRule
if _, err := a.db.Query(&lines, `SELECT * FROM casbin_rules`); err != nil { if err := a.db.Model(&lines).Select(); err != nil {
return err return err
} }
...@@ -236,18 +257,18 @@ func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { ...@@ -236,18 +257,18 @@ func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
// AddPolicies adds policy rules to the storage. // AddPolicies adds policy rules to the storage.
func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error { func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error {
var lines []*CasbinRule var lines []*CasbinRule
for _,rule := range rules{ for _, rule := range rules {
line := savePolicyLine(ptype, rule) line := savePolicyLine(ptype, rule)
lines = append(lines, line) lines = append(lines, line)
} }
err := a.db.RunInTransaction(func(tx *pg.Tx) error { err := a.db.RunInTransaction(func(tx *pg.Tx) error {
_, err := tx.Model(&lines). _, err := tx.Model(&lines).
OnConflict("DO NOTHING"). OnConflict("DO NOTHING").
Insert() Insert()
return err return err
}) })
return err return err
} }
...@@ -261,17 +282,17 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { ...@@ -261,17 +282,17 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
// RemovePolicies removes policy rules from the storage. // RemovePolicies removes policy rules from the storage.
func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
var lines []*CasbinRule var lines []*CasbinRule
for _,rule := range rules{ for _, rule := range rules {
line := savePolicyLine(ptype, rule) line := savePolicyLine(ptype, rule)
lines = append(lines, line) lines = append(lines, line)
} }
err := a.db.RunInTransaction(func(tx *pg.Tx) error { err := a.db.RunInTransaction(func(tx *pg.Tx) error {
_, err := tx.Model(&lines). _, err := tx.Model(&lines).
Delete() Delete()
return err return err
}) })
return err return err
} }
......
...@@ -120,17 +120,17 @@ func (s *AdapterTestSuite) TestAutoSave() { ...@@ -120,17 +120,17 @@ func (s *AdapterTestSuite) TestAutoSave() {
// The policy has a new rule: {"alice", "data1", "write"}. // The policy has a new rule: {"alice", "data1", "write"}.
s.assertPolicy( s.assertPolicy(
[][]string{ [][]string{
{"alice", "data1", "read"}, {"alice", "data1", "read"},
{"bob", "data2", "write"}, {"bob", "data2", "write"},
{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "read"},
{"data2_admin", "data2", "write"}, {"data2_admin", "data2", "write"},
{"alice", "data1", "write"}, {"alice", "data1", "write"},
{"bob", "data2", "read"}, {"bob", "data2", "read"},
{"alice", "data2", "write"}, {"alice", "data2", "write"},
{"alice", "data2", "read"}, {"alice", "data2", "read"},
{"bob", "data1", "write"}, {"bob", "data1", "write"},
{"bob", "data1", "read"}, {"bob", "data1", "read"},
}, },
s.e.GetPolicy(), s.e.GetPolicy(),
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment