Refactoring to allow for separation of activerecord-import code and activerecord's connection adapters.
This commit is contained in:
parent
b6c72f8456
commit
948dec0be2
2
lib/activerecord-import.rb
Normal file
2
lib/activerecord-import.rb
Normal file
|
@ -0,0 +1,2 @@
|
|||
require File.join File.dirname(__FILE__), "activerecord-import/base"
|
||||
ActiveRecord::Extensions.load
|
|
@ -1,146 +1,10 @@
|
|||
require "activerecord-import/adapters/abstract_adapter"
|
||||
|
||||
module ActiveRecord # :nodoc:
|
||||
module ConnectionAdapters # :nodoc:
|
||||
class AbstractAdapter # :nodoc:
|
||||
NO_MAX_PACKET = 0
|
||||
QUERY_OVERHEAD = 8 #This was shown to be true for MySQL, but it's not clear where the overhead is from.
|
||||
|
||||
def next_value_for_sequence(sequence_name)
|
||||
%{#{sequence_name}.nextval}
|
||||
end
|
||||
|
||||
# +sql+ can be a single string or an array. If it is an array all
|
||||
# elements that are in position >= 1 will be appended to the final SQL.
|
||||
def insert_many( sql, values, *args ) # :nodoc:
|
||||
# the number of inserts default
|
||||
number_of_inserts = 0
|
||||
|
||||
base_sql,post_sql = if sql.is_a?( String )
|
||||
[ sql, '' ]
|
||||
elsif sql.is_a?( Array )
|
||||
[ sql.shift, sql.join( ' ' ) ]
|
||||
end
|
||||
|
||||
sql_size = QUERY_OVERHEAD + base_sql.size + post_sql.size
|
||||
|
||||
# the number of bytes the requested insert statement values will take up
|
||||
values_in_bytes = self.class.sum_sizes( *values )
|
||||
|
||||
# the number of bytes (commas) it will take to comma separate our values
|
||||
comma_separated_bytes = values.size-1
|
||||
|
||||
# the total number of bytes required if this statement is one statement
|
||||
total_bytes = sql_size + values_in_bytes + comma_separated_bytes
|
||||
|
||||
max = max_allowed_packet
|
||||
|
||||
# if we can insert it all as one statement
|
||||
if NO_MAX_PACKET == max or total_bytes < max
|
||||
number_of_inserts += 1
|
||||
sql2insert = base_sql + values.join( ',' ) + post_sql
|
||||
insert( sql2insert, *args )
|
||||
else
|
||||
value_sets = self.class.get_insert_value_sets( values, sql_size, max )
|
||||
value_sets.each do |values|
|
||||
number_of_inserts += 1
|
||||
sql2insert = base_sql + values.join( ',' ) + post_sql
|
||||
insert( sql2insert, *args )
|
||||
end
|
||||
end
|
||||
|
||||
number_of_inserts
|
||||
end
|
||||
|
||||
def pre_sql_statements(options)
|
||||
sql = []
|
||||
sql << options[:pre_sql] if options[:pre_sql]
|
||||
sql << options[:command] if options[:command]
|
||||
sql << "IGNORE" if options[:ignore]
|
||||
|
||||
#add keywords like IGNORE or DELAYED
|
||||
if options[:keywords].is_a?(Array)
|
||||
sql.concat(options[:keywords])
|
||||
elsif options[:keywords]
|
||||
sql << options[:keywords].to_s
|
||||
end
|
||||
|
||||
sql
|
||||
end
|
||||
|
||||
# Synchronizes the passed in ActiveRecord instances with the records in
|
||||
# the database by calling +reload+ on each instance.
|
||||
def after_import_synchronize( instances )
|
||||
instances.each { |e| e.reload }
|
||||
end
|
||||
|
||||
# Returns an array of post SQL statements given the passed in options.
|
||||
def post_sql_statements( table_name, options ) # :nodoc:
|
||||
post_sql_statements = []
|
||||
if options[:on_duplicate_key_update]
|
||||
post_sql_statements << sql_for_on_duplicate_key_update( table_name, options[:on_duplicate_key_update] )
|
||||
end
|
||||
|
||||
#custom user post_sql
|
||||
post_sql_statements << options[:post_sql] if options[:post_sql]
|
||||
|
||||
#with rollup
|
||||
post_sql_statements << rollup_sql if options[:rollup]
|
||||
|
||||
post_sql_statements
|
||||
end
|
||||
|
||||
|
||||
# Generates the INSERT statement used in insert multiple value sets.
|
||||
def multiple_value_sets_insert_sql(table_name, column_names, options) # :nodoc:
|
||||
"INSERT #{options[:ignore] ? 'IGNORE ':''}INTO #{table_name} (#{column_names.join(',')}) VALUES "
|
||||
end
|
||||
|
||||
# Returns SQL the VALUES for an INSERT statement given the passed in +columns+
|
||||
# and +array_of_attributes+.
|
||||
def values_sql_for_column_names_and_attributes( columns, array_of_attributes ) # :nodoc:
|
||||
values = []
|
||||
array_of_attributes.each do |arr|
|
||||
my_values = []
|
||||
arr.each_with_index do |val,j|
|
||||
my_values << quote( val, columns[j] )
|
||||
end
|
||||
values << my_values
|
||||
end
|
||||
values_arr = values.map{ |arr| '(' + arr.join( ',' ) + ')' }
|
||||
end
|
||||
|
||||
# Returns the sum of the sizes of the passed in objects. This should
|
||||
# probably be moved outside this class, but to where?
|
||||
def self.sum_sizes( *objects ) # :nodoc:
|
||||
objects.inject( 0 ){|sum,o| sum += o.size }
|
||||
end
|
||||
|
||||
# Returns the maximum number of bytes that the server will allow
|
||||
# in a single packet
|
||||
def max_allowed_packet
|
||||
NO_MAX_PACKET
|
||||
end
|
||||
|
||||
def self.get_insert_value_sets( values, sql_size, max_bytes ) # :nodoc:
|
||||
value_sets = []
|
||||
arr, current_arr_values_size, current_size = [], 0, 0
|
||||
values.each_with_index do |val,i|
|
||||
comma_bytes = arr.size
|
||||
sql_size_thus_far = sql_size + current_size + val.size + comma_bytes
|
||||
if NO_MAX_PACKET == max_bytes or sql_size_thus_far <= max_bytes
|
||||
current_size += val.size
|
||||
arr << val
|
||||
else
|
||||
value_sets << arr
|
||||
arr = [ val ]
|
||||
current_size = val.size
|
||||
end
|
||||
|
||||
# if we're on the last iteration push whatever we have in arr to value_sets
|
||||
value_sets << arr if i == (values.size-1)
|
||||
end
|
||||
[ *value_sets ]
|
||||
end
|
||||
|
||||
extend ActiveRecord::Extensions::Import::AbstractAdapter::ClassMethods
|
||||
include ActiveRecord::Extensions::Import::AbstractAdapter::InstanceMethods
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
require "active_record/connection_adapters/mysql2_adapter"
|
||||
require "activerecord-import/active_record/adapters/mysql_base"
|
||||
require "activerecord-import/adapters/mysql_adapter"
|
||||
|
||||
class ActiveRecord::ConnectionAdapters::Mysql2Adapter
|
||||
include ActiveRecord::ConnectionAdapters::MysqlBase
|
||||
include ActiveRecord::Extensions::Import::MysqlAdapter::InstanceMethods
|
||||
end
|
|
@ -1,6 +1,6 @@
|
|||
require "active_record/connection_adapters/mysql_adapter"
|
||||
require "activerecord-import/active_record/adapters/mysql_base"
|
||||
require "activerecord-import/adapters/mysql_adapter"
|
||||
|
||||
class ActiveRecord::ConnectionAdapters::MysqlAdapter
|
||||
include ActiveRecord::ConnectionAdapters::MysqlBase
|
||||
include ActiveRecord::Extensions::Import::MysqlAdapter::InstanceMethods
|
||||
end
|
|
@ -1,48 +0,0 @@
|
|||
module ActiveRecord::ConnectionAdapters::MysqlBase
|
||||
def self.included(klass)
|
||||
klass.instance_eval do
|
||||
include ActiveRecord::Extensions::Import::ImportSupport
|
||||
include ActiveRecord::Extensions::Import::OnDuplicateKeyUpdateSupport
|
||||
end
|
||||
end
|
||||
|
||||
# Returns a generated ON DUPLICATE KEY UPDATE statement given the passed
|
||||
# in +args+.
|
||||
def sql_for_on_duplicate_key_update( table_name, *args ) # :nodoc:
|
||||
sql = ' ON DUPLICATE KEY UPDATE '
|
||||
arg = args.first
|
||||
if arg.is_a?( Array )
|
||||
sql << sql_for_on_duplicate_key_update_as_array( table_name, arg )
|
||||
elsif arg.is_a?( Hash )
|
||||
sql << sql_for_on_duplicate_key_update_as_hash( table_name, arg )
|
||||
elsif arg.is_a?( String )
|
||||
sql << arg
|
||||
else
|
||||
raise ArgumentError.new( "Expected Array or Hash" )
|
||||
end
|
||||
sql
|
||||
end
|
||||
|
||||
def sql_for_on_duplicate_key_update_as_array( table_name, arr ) # :nodoc:
|
||||
results = arr.map do |column|
|
||||
qc = quote_column_name( column )
|
||||
"#{table_name}.#{qc}=VALUES(#{qc})"
|
||||
end
|
||||
results.join( ',' )
|
||||
end
|
||||
|
||||
def sql_for_on_duplicate_key_update_as_hash( table_name, hsh ) # :nodoc:
|
||||
sql = ' ON DUPLICATE KEY UPDATE '
|
||||
results = hsh.map do |column1, column2|
|
||||
qc1 = quote_column_name( column1 )
|
||||
qc2 = quote_column_name( column2 )
|
||||
"#{table_name}.#{qc1}=VALUES( #{qc2} )"
|
||||
end
|
||||
results.join( ',')
|
||||
end
|
||||
|
||||
#return true if the statement is a duplicate key record error
|
||||
def duplicate_key_update_error?(exception)# :nodoc:
|
||||
exception.is_a?(ActiveRecord::StatementInvalid) && exception.to_s.include?('Duplicate entry')
|
||||
end
|
||||
end
|
|
@ -1,11 +1,7 @@
|
|||
require "active_record/connection_adapters/postgresql_adapter"
|
||||
require "activerecord-import/adapters/postgresql_adapter"
|
||||
|
||||
module ActiveRecord # :nodoc:
|
||||
module ConnectionAdapters # :nodoc:
|
||||
class PostgreSQLAdapter # :nodoc:
|
||||
def next_value_for_sequence(sequence_name)
|
||||
%{nextval('#{sequence_name}')}
|
||||
end
|
||||
end
|
||||
end
|
||||
class ActiveRecord::ConnectionAdapters::PostgreSQLAdapter
|
||||
include ActiveRecord::Extensions::Import::PostgreSQLAdapter::InstanceMethods
|
||||
end
|
||||
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
require "active_record/connection_adapters/sqlite3_adapter"
|
||||
require "activerecord-import/adapters/sqlite3_adapter"
|
||||
|
||||
module ActiveRecord # :nodoc:
|
||||
module ConnectionAdapters # :nodoc:
|
||||
class Sqlite3Adapter # :nodoc:
|
||||
def next_value_for_sequence(sequence_name)
|
||||
%{nextval('#{sequence_name}')}
|
||||
end
|
||||
end
|
||||
end
|
||||
class ActiveRecord::ConnectionAdapters::Sqlite3Adapter
|
||||
include ActiveRecord::Extensions::Import::Sqlite3Adapter::InstanceMethods
|
||||
end
|
||||
|
||||
|
|
145
lib/activerecord-import/adapters/abstract_adapter.rb
Normal file
145
lib/activerecord-import/adapters/abstract_adapter.rb
Normal file
|
@ -0,0 +1,145 @@
|
|||
module ActiveRecord::Extensions::Import::AbstractAdapter
|
||||
NO_MAX_PACKET = 0
|
||||
QUERY_OVERHEAD = 8 #This was shown to be true for MySQL, but it's not clear where the overhead is from.
|
||||
|
||||
module ClassMethods
|
||||
# Returns the sum of the sizes of the passed in objects. This should
|
||||
# probably be moved outside this class, but to where?
|
||||
def sum_sizes( *objects ) # :nodoc:
|
||||
objects.inject( 0 ){|sum,o| sum += o.size }
|
||||
end
|
||||
|
||||
def get_insert_value_sets( values, sql_size, max_bytes ) # :nodoc:
|
||||
value_sets = []
|
||||
arr, current_arr_values_size, current_size = [], 0, 0
|
||||
values.each_with_index do |val,i|
|
||||
comma_bytes = arr.size
|
||||
sql_size_thus_far = sql_size + current_size + val.size + comma_bytes
|
||||
if NO_MAX_PACKET == max_bytes or sql_size_thus_far <= max_bytes
|
||||
current_size += val.size
|
||||
arr << val
|
||||
else
|
||||
value_sets << arr
|
||||
arr = [ val ]
|
||||
current_size = val.size
|
||||
end
|
||||
|
||||
# if we're on the last iteration push whatever we have in arr to value_sets
|
||||
value_sets << arr if i == (values.size-1)
|
||||
end
|
||||
[ *value_sets ]
|
||||
end
|
||||
end
|
||||
|
||||
module InstanceMethods
|
||||
def next_value_for_sequence(sequence_name)
|
||||
%{#{sequence_name}.nextval}
|
||||
end
|
||||
|
||||
# +sql+ can be a single string or an array. If it is an array all
|
||||
# elements that are in position >= 1 will be appended to the final SQL.
|
||||
def insert_many( sql, values, *args ) # :nodoc:
|
||||
# the number of inserts default
|
||||
number_of_inserts = 0
|
||||
|
||||
base_sql,post_sql = if sql.is_a?( String )
|
||||
[ sql, '' ]
|
||||
elsif sql.is_a?( Array )
|
||||
[ sql.shift, sql.join( ' ' ) ]
|
||||
end
|
||||
|
||||
sql_size = QUERY_OVERHEAD + base_sql.size + post_sql.size
|
||||
|
||||
# the number of bytes the requested insert statement values will take up
|
||||
values_in_bytes = self.class.sum_sizes( *values )
|
||||
|
||||
# the number of bytes (commas) it will take to comma separate our values
|
||||
comma_separated_bytes = values.size-1
|
||||
|
||||
# the total number of bytes required if this statement is one statement
|
||||
total_bytes = sql_size + values_in_bytes + comma_separated_bytes
|
||||
|
||||
max = max_allowed_packet
|
||||
|
||||
# if we can insert it all as one statement
|
||||
if NO_MAX_PACKET == max or total_bytes < max
|
||||
number_of_inserts += 1
|
||||
sql2insert = base_sql + values.join( ',' ) + post_sql
|
||||
insert( sql2insert, *args )
|
||||
else
|
||||
value_sets = self.class.get_insert_value_sets( values, sql_size, max )
|
||||
value_sets.each do |values|
|
||||
number_of_inserts += 1
|
||||
sql2insert = base_sql + values.join( ',' ) + post_sql
|
||||
insert( sql2insert, *args )
|
||||
end
|
||||
end
|
||||
|
||||
number_of_inserts
|
||||
end
|
||||
|
||||
def pre_sql_statements(options)
|
||||
sql = []
|
||||
sql << options[:pre_sql] if options[:pre_sql]
|
||||
sql << options[:command] if options[:command]
|
||||
sql << "IGNORE" if options[:ignore]
|
||||
|
||||
#add keywords like IGNORE or DELAYED
|
||||
if options[:keywords].is_a?(Array)
|
||||
sql.concat(options[:keywords])
|
||||
elsif options[:keywords]
|
||||
sql << options[:keywords].to_s
|
||||
end
|
||||
|
||||
sql
|
||||
end
|
||||
|
||||
# Synchronizes the passed in ActiveRecord instances with the records in
|
||||
# the database by calling +reload+ on each instance.
|
||||
def after_import_synchronize( instances )
|
||||
instances.each { |e| e.reload }
|
||||
end
|
||||
|
||||
# Returns an array of post SQL statements given the passed in options.
|
||||
def post_sql_statements( table_name, options ) # :nodoc:
|
||||
post_sql_statements = []
|
||||
if options[:on_duplicate_key_update]
|
||||
post_sql_statements << sql_for_on_duplicate_key_update( table_name, options[:on_duplicate_key_update] )
|
||||
end
|
||||
|
||||
#custom user post_sql
|
||||
post_sql_statements << options[:post_sql] if options[:post_sql]
|
||||
|
||||
#with rollup
|
||||
post_sql_statements << rollup_sql if options[:rollup]
|
||||
|
||||
post_sql_statements
|
||||
end
|
||||
|
||||
|
||||
# Generates the INSERT statement used in insert multiple value sets.
|
||||
def multiple_value_sets_insert_sql(table_name, column_names, options) # :nodoc:
|
||||
"INSERT #{options[:ignore] ? 'IGNORE ':''}INTO #{table_name} (#{column_names.join(',')}) VALUES "
|
||||
end
|
||||
|
||||
# Returns SQL the VALUES for an INSERT statement given the passed in +columns+
|
||||
# and +array_of_attributes+.
|
||||
def values_sql_for_column_names_and_attributes( columns, array_of_attributes ) # :nodoc:
|
||||
values = []
|
||||
array_of_attributes.each do |arr|
|
||||
my_values = []
|
||||
arr.each_with_index do |val,j|
|
||||
my_values << quote( val, columns[j] )
|
||||
end
|
||||
values << my_values
|
||||
end
|
||||
values_arr = values.map{ |arr| '(' + arr.join( ',' ) + ')' }
|
||||
end
|
||||
|
||||
# Returns the maximum number of bytes that the server will allow
|
||||
# in a single packet
|
||||
def max_allowed_packet
|
||||
NO_MAX_PACKET
|
||||
end
|
||||
end
|
||||
end
|
50
lib/activerecord-import/adapters/mysql_adapter.rb
Normal file
50
lib/activerecord-import/adapters/mysql_adapter.rb
Normal file
|
@ -0,0 +1,50 @@
|
|||
module ActiveRecord::Extensions::Import::MysqlAdapter
|
||||
module InstanceMethods
|
||||
def self.included(klass)
|
||||
klass.instance_eval do
|
||||
include ActiveRecord::Extensions::Import::ImportSupport
|
||||
include ActiveRecord::Extensions::Import::OnDuplicateKeyUpdateSupport
|
||||
end
|
||||
end
|
||||
|
||||
# Returns a generated ON DUPLICATE KEY UPDATE statement given the passed
|
||||
# in +args+.
|
||||
def sql_for_on_duplicate_key_update( table_name, *args ) # :nodoc:
|
||||
sql = ' ON DUPLICATE KEY UPDATE '
|
||||
arg = args.first
|
||||
if arg.is_a?( Array )
|
||||
sql << sql_for_on_duplicate_key_update_as_array( table_name, arg )
|
||||
elsif arg.is_a?( Hash )
|
||||
sql << sql_for_on_duplicate_key_update_as_hash( table_name, arg )
|
||||
elsif arg.is_a?( String )
|
||||
sql << arg
|
||||
else
|
||||
raise ArgumentError.new( "Expected Array or Hash" )
|
||||
end
|
||||
sql
|
||||
end
|
||||
|
||||
def sql_for_on_duplicate_key_update_as_array( table_name, arr ) # :nodoc:
|
||||
results = arr.map do |column|
|
||||
qc = quote_column_name( column )
|
||||
"#{table_name}.#{qc}=VALUES(#{qc})"
|
||||
end
|
||||
results.join( ',' )
|
||||
end
|
||||
|
||||
def sql_for_on_duplicate_key_update_as_hash( table_name, hsh ) # :nodoc:
|
||||
sql = ' ON DUPLICATE KEY UPDATE '
|
||||
results = hsh.map do |column1, column2|
|
||||
qc1 = quote_column_name( column1 )
|
||||
qc2 = quote_column_name( column2 )
|
||||
"#{table_name}.#{qc1}=VALUES( #{qc2} )"
|
||||
end
|
||||
results.join( ',')
|
||||
end
|
||||
|
||||
#return true if the statement is a duplicate key record error
|
||||
def duplicate_key_update_error?(exception)# :nodoc:
|
||||
exception.is_a?(ActiveRecord::StatementInvalid) && exception.to_s.include?('Duplicate entry')
|
||||
end
|
||||
end
|
||||
end
|
7
lib/activerecord-import/adapters/postgresql_adapter.rb
Normal file
7
lib/activerecord-import/adapters/postgresql_adapter.rb
Normal file
|
@ -0,0 +1,7 @@
|
|||
module ActiveRecord::Extensions::Import::PostgreSQLAdapter
|
||||
module InstanceMethods
|
||||
def next_value_for_sequence(sequence_name)
|
||||
%{nextval('#{sequence_name}')}
|
||||
end
|
||||
end
|
||||
end
|
7
lib/activerecord-import/adapters/sqlite3_adapter.rb
Normal file
7
lib/activerecord-import/adapters/sqlite3_adapter.rb
Normal file
|
@ -0,0 +1,7 @@
|
|||
module ActiveRecord::Extensions::Import::Sqlite3Adapter
|
||||
module InstanceMethods
|
||||
def next_value_for_sequence(sequence_name)
|
||||
%{nextval('#{sequence_name}')}
|
||||
end
|
||||
end
|
||||
end
|
|
@ -5,10 +5,17 @@ require "active_record/version"
|
|||
module ActiveRecord::Extensions
|
||||
AdapterPath = File.join File.expand_path(File.dirname(__FILE__)), "/active_record/adapters"
|
||||
|
||||
# Loads the import functionality for a specific database adapter
|
||||
def self.require_adapter(adapter)
|
||||
require File.join(AdapterPath,"/abstract_adapter")
|
||||
require File.join(AdapterPath,"/#{adapter}_adapter")
|
||||
end
|
||||
|
||||
# Loads the import functionality for the current ActiveRecord::Base.connection
|
||||
def self.load
|
||||
config = ActiveRecord::Base.connection.instance_variable_get :@config
|
||||
require_adapter config[:adapter]
|
||||
end
|
||||
end
|
||||
|
||||
this_dir = Pathname.new File.dirname(__FILE__)
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
require File.join File.dirname(__FILE__), "base"
|
||||
ActiveRecord::Extensions.require_adapter "mysql"
|
||||
require File.expand_path(File.join(File.dirname(__FILE__), "/../activerecord-import"))
|
||||
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
require File.join File.dirname(__FILE__), "base"
|
||||
ActiveRecord::Extensions.require_adapter "mysql2"
|
||||
require File.expand_path(File.join(File.dirname(__FILE__), "/../activerecord-import"))
|
||||
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
require File.join File.dirname(__FILE__), "base"
|
||||
ActiveRecord::Extensions.require_adapter "postgresql"
|
||||
require File.expand_path(File.join(File.dirname(__FILE__), "/../activerecord-import"))
|
||||
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
require File.join File.dirname(__FILE__), "base"
|
||||
ActiveRecord::Extensions.require_adapter "sqlite3"
|
||||
require File.expand_path(File.join(File.dirname(__FILE__), "/../activerecord-import"))
|
||||
|
||||
|
|
67
test/support/active_support/test_case_extensions.rb
Normal file
67
test/support/active_support/test_case_extensions.rb
Normal file
|
@ -0,0 +1,67 @@
|
|||
class ActiveSupport::TestCase
|
||||
include ActiveRecord::TestFixtures
|
||||
self.use_transactional_fixtures = true
|
||||
|
||||
class << self
|
||||
def assertion(name, &block)
|
||||
mc = class << self ; self ; end
|
||||
mc.class_eval do
|
||||
define_method(name) do
|
||||
it(name, &block)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def asssertion_group(name, &block)
|
||||
mc = class << self ; self ; end
|
||||
mc.class_eval do
|
||||
define_method(name, &block)
|
||||
end
|
||||
end
|
||||
|
||||
def macro(name, &block)
|
||||
class_eval do
|
||||
define_method(name, &block)
|
||||
end
|
||||
end
|
||||
|
||||
def describe(description, toplevel=nil, &blk)
|
||||
text = toplevel ? description : "#{name} #{description}"
|
||||
klass = Class.new(self)
|
||||
|
||||
klass.class_eval <<-RUBY_EVAL
|
||||
def self.name
|
||||
"#{text}"
|
||||
end
|
||||
RUBY_EVAL
|
||||
|
||||
# do not inherit test methods from the superclass
|
||||
klass.class_eval do
|
||||
instance_methods.grep(/^test.+/) do |method|
|
||||
undef_method method
|
||||
end
|
||||
end
|
||||
|
||||
klass.instance_eval &blk
|
||||
end
|
||||
alias_method :context, :describe
|
||||
|
||||
def let(name, &blk)
|
||||
values = {}
|
||||
define_method(name) do
|
||||
return values[name] if values.has_key?(name)
|
||||
values[name] = instance_eval(&blk)
|
||||
end
|
||||
end
|
||||
|
||||
def it(description, &blk)
|
||||
define_method("test: #{name} #{description}", &blk)
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
def describe(description, &blk)
|
||||
ActiveSupport::TestCase.describe(description, true, &blk)
|
||||
end
|
||||
|
|
@ -11,98 +11,28 @@ ENV["RAILS_ENV"] = "test"
|
|||
require "bundler"
|
||||
Bundler.setup
|
||||
|
||||
require "logger"
|
||||
require "rails"
|
||||
require "rails/test_help"
|
||||
require "active_record"
|
||||
require "active_record/fixtures"
|
||||
require "active_support/test_case"
|
||||
|
||||
require "delorean"
|
||||
|
||||
require "active_record"
|
||||
require "logger"
|
||||
|
||||
require "ruby-debug"
|
||||
|
||||
# load test helpers
|
||||
class MyApplication < Rails::Application ; end
|
||||
adapter = ENV["ARE_DB"] || "sqlite3"
|
||||
|
||||
# load the library
|
||||
require "activerecord-import/#{adapter}"
|
||||
|
||||
class ActiveSupport::TestCase
|
||||
include ActiveRecord::TestFixtures
|
||||
self.use_transactional_fixtures = true
|
||||
|
||||
class << self
|
||||
def assertion(name, &block)
|
||||
mc = class << self ; self ; end
|
||||
mc.class_eval do
|
||||
define_method(name) do
|
||||
it(name, &block)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def asssertion_group(name, &block)
|
||||
mc = class << self ; self ; end
|
||||
mc.class_eval do
|
||||
define_method(name, &block)
|
||||
end
|
||||
end
|
||||
|
||||
def macro(name, &block)
|
||||
class_eval do
|
||||
define_method(name, &block)
|
||||
end
|
||||
end
|
||||
|
||||
def describe(description, toplevel=nil, &blk)
|
||||
text = toplevel ? description : "#{name} #{description}"
|
||||
klass = Class.new(self)
|
||||
|
||||
klass.class_eval <<-RUBY_EVAL
|
||||
def self.name
|
||||
"#{text}"
|
||||
end
|
||||
RUBY_EVAL
|
||||
|
||||
# do not inherit test methods from the superclass
|
||||
klass.class_eval do
|
||||
instance_methods.grep(/^test.+/) do |method|
|
||||
undef_method method
|
||||
end
|
||||
end
|
||||
|
||||
klass.instance_eval &blk
|
||||
end
|
||||
alias_method :context, :describe
|
||||
|
||||
def let(name, &blk)
|
||||
values = {}
|
||||
define_method(name) do
|
||||
return values[name] if values.has_key?(name)
|
||||
values[name] = instance_eval(&blk)
|
||||
end
|
||||
end
|
||||
|
||||
def it(description, &blk)
|
||||
define_method("test: #{name} #{description}", &blk)
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
def describe(description, &blk)
|
||||
ActiveSupport::TestCase.describe(description, true, &blk)
|
||||
end
|
||||
|
||||
FileUtils.mkdir_p 'log'
|
||||
ActiveRecord::Base.logger = Logger.new("log/test.log")
|
||||
ActiveRecord::Base.logger.level = Logger::DEBUG
|
||||
ActiveRecord::Base.configurations["test"] = YAML.load(test_dir.join("database.yml").open)[adapter]
|
||||
ActiveRecord::Base.establish_connection "test"
|
||||
|
||||
# load the library
|
||||
require "activerecord-import"
|
||||
|
||||
ActiveSupport::Notifications.subscribe(/active_record.sql/) do |event, _, _, _, hsh|
|
||||
ActiveRecord::Base.logger.info hsh[:sql]
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue